From 853e088ed8c1f20f079f266ec552e85f9d6f1a9f Mon Sep 17 00:00:00 2001
From: Arnav Singhvi
Date: Wed, 22 Feb 2023 12:39:46 -0800
Subject: [PATCH 1/5] network transformation function & testing for PyTorch to POET

---
 .../network_transform_pytorch.py              | 71 +++++++++++++++++
 .../pytorch/test_resnet_network_transform.py  | 20 +++++
 .../pytorch/test_vgg_network_transform.py     | 16 ++++
 3 files changed, 107 insertions(+)
 create mode 100644 poet/architectures/network_transform_pytorch.py
 create mode 100644 poet/transformation_testing/pytorch/test_resnet_network_transform.py
 create mode 100644 poet/transformation_testing/pytorch/test_vgg_network_transform.py

diff --git a/poet/architectures/network_transform_pytorch.py b/poet/architectures/network_transform_pytorch.py
new file mode 100644
index 0000000..890fcb3
--- /dev/null
+++ b/poet/architectures/network_transform_pytorch.py
@@ -0,0 +1,71 @@
+from poet.power_computation import (
+    LinearLayer,
+    ReLULayer,
+    Conv2dLayer,
+    FlattenLayer,
+    TanHLayer,
+    SigmoidLayer,
+    SkipAddLayer,
+    DropoutLayer,
+    GradientLayer,
+    InputLayer,
+    CrossEntropyLoss,
+    BatchNorm2d,
+    MaxPool2d,
+    AvgPool2d,
+    GlobalAvgPool,
+    get_net_costs,
+)
+from poet.power_computation_transformer import QueryKeyValueMatrix, QKTMatrix, QKTVMatrix
+import torch.nn as nn
+import torchvision.models
+from torchvision.models.resnet import BasicBlock, Bottleneck
+
+
+# transforms input model's network layers to output graph with POET layers for PyTorch models
+def network_transform(net, layers, batch_size, num_classes, input_shape):
+    if isinstance(net, torchvision.models.resnet.ResNet):
+        modules = nn.Sequential(*list(net.children()))
+    elif isinstance(net, torchvision.models.vgg.VGG):
+        modules = list(net.children())
+    else:
+        modules = net
+    for module in modules:
+        if isinstance(module, nn.Sequential):
+            sequential_modules = [child for child in module]
+            transformed = network_transform(sequential_modules, [layers[-1]], batch_size, num_classes, input_shape)
+            layers.extend(transformed[1:])
+        if isinstance(module, BasicBlock) or isinstance(module, Bottleneck):
+            transformed = network_transform(nn.Sequential(*list(module.children())), [layers[-1]], batch_size, num_classes, input_shape)
+            layers.extend(transformed[1:])
+        if isinstance(module, nn.Linear):
+            lin_layer = LinearLayer(module.in_features, module.out_features, layers[-1])
+            act_layer = ReLULayer(lin_layer)
+            layers.extend([lin_layer, act_layer])
+        if isinstance(module, nn.ReLU):
+            relu_layer = ReLULayer(layers[-1])
+            layers.append(relu_layer)
+        if isinstance(module, nn.Conv2d):
+            conv_layer = Conv2dLayer(
+                module.in_channels, module.out_channels, module.kernel_size, module.stride[0], module.padding, layers[-1]
+            )
+            layers.append(conv_layer)
+        if isinstance(module, nn.BatchNorm2d):
+            layers.append(BatchNorm2d(layers[-1]))
+        if isinstance(module, nn.MaxPool2d):
+            layers.append(MaxPool2d((module.kernel_size, module.kernel_size), module.stride, layers[-1]))
+        if isinstance(module, nn.AvgPool2d):
+            layers.append(AvgPool2d(module.kernel_size, module.stride, layers[-1]))
+        if isinstance(module, nn.Tanh):
+            tanh_layer = TanHLayer(layers[-1])
+            layers.append(tanh_layer)
+        if isinstance(module, nn.Sigmoid):
+            sigmoid_layer = SigmoidLayer(layers[-1])
+            layers.append(sigmoid_layer)
+        if isinstance(module, nn.Flatten):
+            flatten_layer = FlattenLayer(layers[-1])
+            layers.append(flatten_layer)
+        if isinstance(module, nn.Dropout):
+            dropout_layer = DropoutLayer(layers[-1])
+            layers.append(dropout_layer)
+    return layers
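For reference, a minimal way to drive this transform (a sketch mirroring the test scripts added below; assumes the torchvision weights can be downloaded):

    from poet.power_computation import InputLayer
    import torchvision.models
    from poet.architectures.network_transform_pytorch import network_transform

    batch_size = (1,)
    input_shape = (3, 32, 32)
    # seed the layer list with an InputLayer, as the tests below do
    layers = [InputLayer((batch_size, *input_shape))]
    final_layers = network_transform(torchvision.models.resnet18(pretrained=True), layers, batch_size, 10, input_shape)
    print(final_layers)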
diff --git a/poet/transformation_testing/pytorch/test_resnet_network_transform.py b/poet/transformation_testing/pytorch/test_resnet_network_transform.py
new file mode 100644
index 0000000..78537f7
--- /dev/null
+++ b/poet/transformation_testing/pytorch/test_resnet_network_transform.py
@@ -0,0 +1,20 @@
+from poet.power_computation import InputLayer
+import torchvision.models
+from poet.architectures.network_transform_pytorch import network_transform
+
+# transforms ResNet model network layers into POET computation layers
+
+batch_size = (1,)
+input_shape = (3, 32, 32)
+num_classes = 10
+layers = [InputLayer((batch_size, *input_shape))]
+
+# uncomment to output transformations for different ResNet models
+
+# # ResNet18 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
+# final_layers = network_transform(torchvision.models.resnet18(pretrained=True), layers, batch_size, num_classes, input_shape)
+# print(final_layers)
+
+# # ResNet50 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
+# final_layers = network_transform(torchvision.models.resnet50(pretrained=True), layers, batch_size, num_classes, input_shape)
+# print(final_layers)
diff --git a/poet/transformation_testing/pytorch/test_vgg_network_transform.py b/poet/transformation_testing/pytorch/test_vgg_network_transform.py
new file mode 100644
index 0000000..9f753cf
--- /dev/null
+++ b/poet/transformation_testing/pytorch/test_vgg_network_transform.py
@@ -0,0 +1,16 @@
+from poet.power_computation import InputLayer
+import torch
+from poet.architectures.network_transform_pytorch import network_transform
+
+# transforms VGG model network layers into POET computation layers
+
+batch_size = (1,)
+input_shape = (3, 32, 32)
+num_classes = 10
+layers = [InputLayer((batch_size, *input_shape))]
+
+# VGG16 model transformation - https://download.pytorch.org/models/vgg16-397923af.pth
+final_layers = network_transform(
+    torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True), layers, batch_size, num_classes, input_shape
+)
+print(final_layers)
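Because network_transform recurses into nn.Sequential containers and into ResNet BasicBlock/Bottleneck blocks, nested models flatten into a single POET layer list. A toy illustration (hypothetical model, not part of the patch; shapes are illustrative since the model is only introspected, never run):

    import torch.nn as nn
    from poet.power_computation import InputLayer
    from poet.architectures.network_transform_pytorch import network_transform

    toy = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Flatten(), nn.Linear(16, 10))
    layers = [InputLayer(((1,), 3, 32, 32))]
    # each torch module maps onto the corresponding POET layer, in order
    print(network_transform(toy, layers, (1,), 10, (3, 32, 32)))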
From 06349a8c876540cb3aa5a31f2094214eb2595ac5 Mon Sep 17 00:00:00 2001
From: Arnav Singhvi
Date: Wed, 22 Feb 2023 12:46:58 -0800
Subject: [PATCH 2/5] network transformation function & testing for TensorFlow to POET

---
 .../network_transform_tensorflow.py           | 59 +++++++++++++
 .../test_resnet_network_transform.py          | 81 +++++++++++++++++++
 2 files changed, 140 insertions(+)
 create mode 100644 poet/architectures/network_transform_tensorflow.py
 create mode 100644 poet/transformation_testing/tensorflow/test_resnet_network_transform.py

diff --git a/poet/architectures/network_transform_tensorflow.py b/poet/architectures/network_transform_tensorflow.py
new file mode 100644
index 0000000..4ef9596
--- /dev/null
+++ b/poet/architectures/network_transform_tensorflow.py
@@ -0,0 +1,59 @@
+from poet.power_computation import (
+    LinearLayer,
+    ReLULayer,
+    Conv2dLayer,
+    FlattenLayer,
+    TanHLayer,
+    SigmoidLayer,
+    SkipAddLayer,
+    DropoutLayer,
+    GradientLayer,
+    InputLayer,
+    CrossEntropyLoss,
+    BatchNorm2d,
+    MaxPool2d,
+    AvgPool2d,
+    GlobalAvgPool,
+    get_net_costs,
+)
+from poet.power_computation_transformer import QueryKeyValueMatrix, QKTMatrix, QKTVMatrix
+import tensorflow as tf
+
+
+# transforms input model's network layers to output graph with POET layers for TensorFlow models
+def network_transform(net, layers, batch_size, num_classes, input_shape):
+    for module in net:
+        if isinstance(module, tf.keras.layers.Dense):
+            lin_layer = LinearLayer(module.units, module.units, layers[-1])
+            act_layer = ReLULayer(lin_layer)
+            layers.extend([lin_layer, act_layer])
+        if isinstance(module, tf.keras.layers.Activation) and module.activation.__name__ == "relu":
+            relu_layer = ReLULayer(layers[-1])
+            layers.append(relu_layer)
+        if isinstance(module, tf.keras.layers.Conv2D):
+            if module.padding == "valid":
+                padding = (0, 0)
+            elif module.padding == "same":
+                padding = (1, 1)
+            conv_layer = Conv2dLayer(1, module.filters, module.kernel_size, module.strides[0], padding, layers[-1])
+            layers.append(conv_layer)
+        if isinstance(module, tf.keras.layers.BatchNormalization):
+            layers.append(BatchNorm2d(layers[-1]))
+        if isinstance(module, tf.keras.layers.MaxPool2D):
+            layers.append(MaxPool2d(module.pool_size, module.strides[0], layers[-1]))
+        if isinstance(module, tf.keras.layers.GlobalAveragePooling2D):
+            if module.keepdims:
+                layers.append(GlobalAvgPool(layers[-1]))
+        if isinstance(module, tf.keras.layers.Activation) and module.activation.__name__ == "tanh":
+            tanh_layer = TanHLayer(layers[-1])
+            layers.append(tanh_layer)
+        if isinstance(module, tf.keras.layers.Activation) and module.activation.__name__ == "sigmoid":
+            sigmoid_layer = SigmoidLayer(layers[-1])
+            layers.append(sigmoid_layer)
+        if isinstance(module, tf.keras.layers.Flatten):
+            flatten_layer = FlattenLayer(layers[-1])
+            layers.append(flatten_layer)
+        if isinstance(module, tf.keras.layers.Dropout):
+            dropout_layer = DropoutLayer(layers[-1])
+            layers.append(dropout_layer)
+    return layers
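A minimal sketch of driving the TensorFlow transform on a toy Sequential model (hypothetical model, not part of the patch; note that Keras does not expose Dense in_features or Conv2D in_channels before build, which is why the transform above approximates them with module.units and a constant 1):

    import tensorflow as tf
    from poet.power_computation import InputLayer
    from poet.architectures.network_transform_tensorflow import network_transform

    toy = tf.keras.Sequential([
        tf.keras.layers.Conv2D(16, 3, input_shape=(32, 32, 3)),
        tf.keras.layers.Activation("relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])
    layers = [InputLayer(((1,), 3, 32, 32))]
    # the function expects an iterable of Keras layers, so pass model.layers
    print(network_transform(toy.layers, layers, (1,), 10, (3, 32, 32)))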
diff --git a/poet/transformation_testing/tensorflow/test_resnet_network_transform.py b/poet/transformation_testing/tensorflow/test_resnet_network_transform.py
new file mode 100644
index 0000000..501e75b
--- /dev/null
+++ b/poet/transformation_testing/tensorflow/test_resnet_network_transform.py
@@ -0,0 +1,81 @@
+from poet.architectures.network_transform_tensorflow import network_transform
+from poet.power_computation import InputLayer
+import tensorflow as tf
+
+# transforms ResNet model network layers into POET computation layers
+
+
+# defining ResNet model in TensorFlow
+def BasicBlock(inputs, num_channels, kernel_size, num_blocks, skip_blocks, name):
+    """Basic residual block"""
+    x = inputs
+    for i in range(num_blocks):
+        if i not in skip_blocks:
+            x1 = ConvNormRelu(x, num_channels, kernel_size, strides=[1, 1], name=name + "." + str(i))
+            x = tf.keras.layers.Add()([x, x1])
+            x = tf.keras.layers.Activation("relu")(x)
+    return x
+
+
+def BasicBlockDown(inputs, num_channels, kernel_size, name):
+    """Residual block with strided downsampling"""
+    x = inputs
+    x1 = ConvNormRelu(x, num_channels, kernel_size, strides=[2, 1], name=name + ".0")
+    x = tf.keras.layers.Conv2D(
+        num_channels, kernel_size=1, strides=2, padding="same", activation="linear", use_bias=False, name=name + ".0.downsample.0"
+    )(x)
+    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".0.downsample.1")(x)
+    x = tf.keras.layers.Add()([x, x1])
+    x = tf.keras.layers.Activation("relu")(x)
+    return x
+
+
+def ConvNormRelu(x, num_channels, kernel_size, strides, name):
+    """Two Conv2D + BatchNorm stages, with a ReLU after the first"""
+    if strides[0] == 2:
+        x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name=name + ".pad")(x)
+        x = tf.keras.layers.Conv2D(
+            num_channels, kernel_size, strides[0], padding="valid", activation="linear", use_bias=False, name=name + ".conv1"
+        )(x)
+    else:
+        x = tf.keras.layers.Conv2D(
+            num_channels, kernel_size, strides[0], padding="same", activation="linear", use_bias=False, name=name + ".conv1"
+        )(x)
+    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".bn1")(x)
+    x = tf.keras.layers.Activation("relu")(x)
+    x = tf.keras.layers.Conv2D(
+        num_channels, kernel_size, strides[1], padding="same", activation="linear", use_bias=False, name=name + ".conv2"
+    )(x)
+    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".bn2")(x)
+    return x
+
+
+def ResNet18(inputs):
+    x = tf.keras.layers.ZeroPadding2D(padding=(3, 3), name="pad")(inputs)
+    x = tf.keras.layers.Conv2D(filters=64, kernel_size=7, strides=2, padding="valid", activation="linear", use_bias=False, name="conv1")(x)
+    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name="bn1")(x)
+    x = tf.keras.layers.Activation("relu", name="relu")(x)
+    x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name="pad1")(x)
+    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="maxpool")(x)
+    x = BasicBlock(x, num_channels=64, kernel_size=3, num_blocks=2, skip_blocks=[], name="layer1")
+    x = BasicBlockDown(x, num_channels=128, kernel_size=3, name="layer2")
+    x = BasicBlock(x, num_channels=128, kernel_size=3, num_blocks=2, skip_blocks=[0], name="layer2")
+    x = BasicBlockDown(x, num_channels=256, kernel_size=3, name="layer3")
+    x = BasicBlock(x, num_channels=256, kernel_size=3, num_blocks=2, skip_blocks=[0], name="layer3")
+    x = BasicBlockDown(x, num_channels=512, kernel_size=3, name="layer4")
+    x = BasicBlock(x, num_channels=512, kernel_size=3, num_blocks=2, skip_blocks=[0], name="layer4")
+    x = tf.keras.layers.GlobalAveragePooling2D(name="avgpool")(x)
+    x = tf.keras.layers.Dense(units=1000, use_bias=True, activation="linear", name="fc")(x)
+    return x
+
+
+inputs = tf.keras.Input((None, None, 3))
+resnet_tf = ResNet18(inputs)
+model = tf.keras.Model(inputs, resnet_tf)
+batch_size = (1,)
+input_shape = (3, 32, 32)
+num_classes = 10
+layers = [InputLayer((batch_size, *input_shape))]

+final_layers = network_transform([layer for layer in model.layers], layers, batch_size, num_classes, input_shape)
+print(final_layers)
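One caveat: the transform only records GlobalAveragePooling2D when keepdims is set, and the avgpool layer above uses the Keras default keepdims=False (the attribute exists in TF >= 2.6), so it is skipped in the output layer list. If that is unintended, building the layer with keepdims would surface it (a sketch of that assumption only; downstream shape implications are left to the caller):

    x = tf.keras.layers.GlobalAveragePooling2D(keepdims=True, name="avgpool")(x)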
 poet/architectures/graph_transformation.py    | 57 +++++++++++++++
 .../pytorch/test_resnet_graph_transform.py    | 19 ++++++
 2 files changed, 76 insertions(+)
 create mode 100644 poet/architectures/graph_transformation.py
 create mode 100644 poet/transformation_testing/pytorch/test_resnet_graph_transform.py

diff --git a/poet/architectures/graph_transformation.py b/poet/architectures/graph_transformation.py
new file mode 100644
index 0000000..f6daad2
--- /dev/null
+++ b/poet/architectures/graph_transformation.py
@@ -0,0 +1,57 @@
+from poet.power_computation import (
+    LinearLayer,
+    ReLULayer,
+    Conv2dLayer,
+    FlattenLayer,
+    TanHLayer,
+    SigmoidLayer,
+    SkipAddLayer,
+    DropoutLayer,
+    GradientLayer,
+    InputLayer,
+    CrossEntropyLoss,
+    BatchNorm2d,
+    MaxPool2d,
+    AvgPool2d,
+    GlobalAvgPool,
+    get_net_costs,
+)
+
+
+# transforms input model's graph to output graph with POET layer nodes
+def graph_transform(traced):
+    for n in traced.graph.nodes:
+        if "
Date: Thu, 23 Feb 2023 09:57:26 -0800
Subject: [PATCH 4/5] implemented rematerialization and paging for PyTorch models, currently only remat completes training

---
 poet/architectures/remat_and_paging.py        |  44 ++++++
 .../pytorch/test_resnet_model_transform.py    | 129 ++++++++++++++++++
 2 files changed, 173 insertions(+)
 create mode 100644 poet/architectures/remat_and_paging.py
 create mode 100644 poet/transformation_testing/pytorch/test_resnet_model_transform.py

diff --git a/poet/architectures/remat_and_paging.py b/poet/architectures/remat_and_paging.py
new file mode 100644
index 0000000..72b3e5d
--- /dev/null
+++ b/poet/architectures/remat_and_paging.py
@@ -0,0 +1,44 @@
+import torch.nn as nn
+
+
+# output layer traversal path for given layer in model
+def get_all_parent_layers(net, type):
+    layers = []
+    for name, module in net.named_modules():
+        if name == type:
+            layer = net
+            attributes = name.strip().split(".")
+            for attr in attributes[:-1]:
+                if not attr.isnumeric():
+                    layer = getattr(layer, attr)
+                else:
+                    layer = layer[int(attr)]
+            layers.append([layer, attributes[-1]])
+    return layers
+
+
+# implements rematerialization and paging techniques in inputted model
+# remat - sets inputted node to Identity layer and saves node and arguments
+# paging - page out layer to cpu
+def memory_saving(model_indexer, node, is_remat, is_page, remat_list):
+    if is_remat:
+        remat_list.append([node.target, getattr(model_indexer[0], model_indexer[1])])
+        setattr(model_indexer[0], model_indexer[1], nn.Identity())
+    elif is_page:
+        layer = getattr(model_indexer[0], model_indexer[1])
+        layer = layer.cpu()
+    return
+
+
+# remat - recomputes inputted node and sets layer back in model
+# paging - page in layer to gpu
+def reuse_layer(model_indexer, node, is_remat, is_page, remat_list):
+    if is_remat:
+        for layer in remat_list:
+            if layer[0] == node.target:
+                break
+        setattr(model_indexer[0], model_indexer[1], layer[1])
+    elif is_page:
+        layer = getattr(model_indexer[0], model_indexer[1])
+        layer = layer.cuda()
+    return
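To make the Identity-swap concrete, a toy round trip of the two helpers' core idea (a hypothetical standalone sketch, not part of the patch):

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    saved = ["0", net[0]]          # what memory_saving appends to remat_list: [target, layer]
    net[0] = nn.Identity()         # memory_saving: drop the layer for the forward pass
    out = net(torch.randn(1, 4))   # forward now runs without the saved layer
    net[0] = saved[1]              # reuse_layer: restore the layer before it is needed again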
diff --git a/poet/transformation_testing/pytorch/test_resnet_model_transform.py b/poet/transformation_testing/pytorch/test_resnet_model_transform.py
new file mode 100644
index 0000000..60117c6
--- /dev/null
+++ b/poet/transformation_testing/pytorch/test_resnet_model_transform.py
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms as transforms
+from torch.fx import symbolic_trace
+import poet.architectures.remat_and_paging as remat_and_paging
+import copy
+
+# Device configuration
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Hyper-parameters
+num_epochs = 150
+learning_rate = 0.01
+
+# Image preprocessing modules
+transform = transforms.Compose([transforms.Pad(4), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()])
+
+# CIFAR-10 dataset
+train_dataset = torchvision.datasets.CIFAR10(root="../../data/", train=True, transform=transform, download=True)
+
+test_dataset = torchvision.datasets.CIFAR10(root="../../data/", train=False, transform=transforms.ToTensor())
+
+# Data loader
+train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
+
+test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)
+
+# Load in model
+model = torchvision.models.resnet18(pretrained=True)
+if torch.cuda.is_available():
+    model.cuda()
+
+# convert model to graph
+traced = symbolic_trace(model)
+torchcopy = copy.deepcopy(traced)
+
+# layers to be rematerialized and paged
+remat_layers = ["layer1.1.conv1", "layer1.0.bn1"]
+paging_layers = ["layer2.1.conv1", "layer2.0.bn1"]
+
+# intermediate storage of remat layers to be recomputed later
+remat_list = []
+
+# Loss and optimizer
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
+
+
+# For updating learning rate
+def update_lr(optimizer, lr):
+    for param_group in optimizer.param_groups:
+        param_group["lr"] = lr
+
+
+# Train the model
+total_step = len(train_loader)
+curr_lr = learning_rate
+for epoch in range(num_epochs):
+    correct = 0
+    total = 0
+    for i, (images, labels) in enumerate(train_loader):
+        images = images.to(device)
+        labels = labels.to(device)
+
+        # Forward pass - includes removing layer temporarily (remat) and paging out layer to cpu
+        for n in torchcopy.graph.nodes:
+            if n.target in remat_layers:
+                model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
+                layer = getattr(model_indexer[0], model_indexer[1])
+                layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, True, False, remat_list))
+            # if n.target in paging_layers:
+            #     model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
+            #     layer = getattr(model_indexer[0], model_indexer[1])
+            #     layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, False, True, remat_list))
+        outputs = model(images)
+        loss = criterion(outputs, labels)
+
+        # Backward and optimize
+        for n in torchcopy.graph.nodes:
+            if n.target in remat_layers:
+                model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
+                layer = getattr(model_indexer[0], model_indexer[1])
+                layer.register_backward_hook(remat_and_paging.reuse_layer(model_indexer, n, True, False, remat_list))
+            # if n.target in paging_layers:
+            #     model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
+            #     layer = getattr(model_indexer[0], model_indexer[1])
+            #     layer.register_backward_hook(remat_and_paging.reuse_layer(model_indexer, n, layer, False, True, remat_list))
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        if (i + 1) % 100 == 0:
+            print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}".format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
+        _, predicted = torch.max(outputs.data, 1)
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+    print("Accuracy of the model on the train images: {} %".format(100 * correct / total))
+    # Decay learning rate
+    if (epoch + 1) % 20 == 0:
+        curr_lr /= 3
+        update_lr(optimizer, curr_lr)
+
+# Test the model
+model.eval()
+with torch.no_grad():
+    correct = 0
+    total = 0
+    for images, labels in test_loader:
+        images = images.to(device)
+        labels = labels.to(device)
+        for n in torchcopy.graph.nodes:
+            if n.target in remat_layers:
+                model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
+                layer = getattr(model_indexer[0], model_indexer[1])
+                layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, True, False, remat_list))
+            # if n.target in paging_layers:
+            #     model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
+            #     layer = getattr(model_indexer[0], model_indexer[1])
+            #     layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, False, True, remat_list))
+        outputs = model(images)
+        _, predicted = torch.max(outputs.data, 1)
+        total += labels.size(0)
+        correct += (predicted == labels).sum().item()
+
+    print("Accuracy of the model on the test images: {} %".format(100 * correct / total))
+
+# Save the model checkpoint
+torch.save(model.state_dict(), "resnet.ckpt")
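One observation on the hook pattern above: memory_saving(...) and reuse_layer(...) are called eagerly at registration time, so the layer swap happens immediately and register_forward_hook receives their None return value rather than a callable. If the intent is to defer the swap until the forward pass actually reaches the layer, the usual pattern is a closure (a sketch of that alternative, not what the patch does):

    layer.register_forward_hook(
        lambda mod, inp, out: remat_and_paging.memory_saving(model_indexer, n, True, False, remat_list)
    )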
From 906aab670ad70eca0bad1b40876735a51211d266 Mon Sep 17 00:00:00 2001
From: Arnav Singhvi
Date: Tue, 4 Apr 2023 10:26:40 -0700
Subject: [PATCH 5/5] added comprehensive comments for remat functions

---
 poet/architectures/remat_and_paging.py        | 45 ++++++++++---------
 poet/poet_solver.py                           |  2 +-
 poet/poet_solver_gurobi.py                    |  3 +-
 .../pytorch/test_resnet_model_transform.py    | 16 +------
 poet/utils/checkmate/core/utils/graph.py      |  2 +-
 5 files changed, 29 insertions(+), 39 deletions(-)

diff --git a/poet/architectures/remat_and_paging.py b/poet/architectures/remat_and_paging.py
index 72b3e5d..76823d5 100644
--- a/poet/architectures/remat_and_paging.py
+++ b/poet/architectures/remat_and_paging.py
@@ -1,44 +1,45 @@
 import torch.nn as nn
 
 
-# output layer traversal path for given layer in model
+# output layer traversal path in model for passed-in layer
 def get_all_parent_layers(net, type):
     layers = []
+    # iterates over all layers in model
     for name, module in net.named_modules():
+        # check if current module name matches specified type
         if name == type:
             layer = net
+            # extracts layer path from its name in the model, which contains indices into sub-modules
             attributes = name.strip().split(".")
             for attr in attributes[:-1]:
                 if not attr.isnumeric():
+                    # retrieve layer attribute
                     layer = getattr(layer, attr)
                 else:
+                    # index into sub-modules list traversing model's path of layers
                     layer = layer[int(attr)]
+            # append list of parent layer and attribute name
             layers.append([layer, attributes[-1]])
     return layers
 
 
-# implements rematerialization and paging techniques in inputted model
-# remat - sets inputted node to Identity layer and saves node and arguments
-# paging - page out layer to cpu
-def memory_saving(model_indexer, node, is_remat, is_page, remat_list):
-    if is_remat:
-        remat_list.append([node.target, getattr(model_indexer[0], model_indexer[1])])
-        setattr(model_indexer[0], model_indexer[1], nn.Identity())
-    elif is_page:
-        layer = getattr(model_indexer[0], model_indexer[1])
-        layer = layer.cpu()
-    return
+# implements rematerialization technique on inputted model
+# which saves passed-in node during forward pass for later recomputation
+# during the backward pass of model
+def memory_saving(model_indexer, node, remat_list):
+    # saves node and arguments for later recomputation
+    remat_list.append([node.target, getattr(model_indexer[0], model_indexer[1])])
+    # sets inputted node to Identity layer
+    setattr(model_indexer[0], model_indexer[1], nn.Identity())
+    return
 
 
-# remat - recomputes inputted node and sets layer back in model
-# paging - page in layer to gpu
-def reuse_layer(model_indexer, node, is_remat, is_page, remat_list):
-    if is_remat:
-        for layer in remat_list:
-            if layer[0] == node.target:
-                break
-        setattr(model_indexer[0], model_indexer[1], layer[1])
-    elif is_page:
-        layer = getattr(model_indexer[0], model_indexer[1])
-        layer = layer.cuda()
+# recomputes inputted node which was rematerialized and sets layer back into model
+def reuse_layer(model_indexer, node, remat_list):
+    # iterates over rematerialized nodes to find matching layer
+    for layer in remat_list:
+        if layer[0] == node.target:
+            break
+    # sets inputted node back to its original state
+    setattr(model_indexer[0], model_indexer[1], layer[1])
     return
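For orientation, how the traversal is used by the training script (a sketch; "layer1.0.bn1" is one of the layer names the test targets):

    import torchvision
    import poet.architectures.remat_and_paging as remat_and_paging

    model = torchvision.models.resnet18()
    parent, attr = remat_and_paging.get_all_parent_layers(model, "layer1.0.bn1")[0]
    print(getattr(parent, attr))  # the BatchNorm2d module that remat would swap for nn.Identity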
diff --git a/poet/poet_solver.py b/poet/poet_solver.py
index 8cff292..15976f4 100644
--- a/poet/poet_solver.py
+++ b/poet/poet_solver.py
@@ -107,7 +107,7 @@ def _initialize_variables(self):
 
     def _create_correctness_constraints(self):
         # ensure all computations are possible
-        for (u, v) in self.g.edge_list:
+        for u, v in self.g.edge_list:
             for t in range(self.T):
                 self.m += self.R[t][v] <= self.R[t][u] + self.SRam[t][u]
         # ensure all checkpoints are in memory
diff --git a/poet/poet_solver_gurobi.py b/poet/poet_solver_gurobi.py
index cb25b29..e98da9b 100644
--- a/poet/poet_solver_gurobi.py
+++ b/poet/poet_solver_gurobi.py
@@ -13,6 +13,7 @@
 
 # noinspection PyPackageRequirements
 
+
 # POET ILP defined using Gurobi
 class POETSolverGurobi:
     def __init__(
@@ -124,7 +125,7 @@
 
     def _create_correctness_constraints(self):
         # ensure all computations are possible
-        for (u, v) in self.g.edge_list:
+        for u, v in self.g.edge_list:
             for t in range(self.T):
                 self.m.addLConstr(self.R[t, v], GRB.LESS_EQUAL, self.R[t, u] + self.SRam[t, u])
         # ensure all checkpoints are in memory
diff --git a/poet/transformation_testing/pytorch/test_resnet_model_transform.py b/poet/transformation_testing/pytorch/test_resnet_model_transform.py
index 60117c6..9e20116 100644
--- a/poet/transformation_testing/pytorch/test_resnet_model_transform.py
+++ b/poet/transformation_testing/pytorch/test_resnet_model_transform.py
@@ -68,11 +68,7 @@ def update_lr(optimizer, lr):
             if n.target in remat_layers:
                 model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
                 layer = getattr(model_indexer[0], model_indexer[1])
-                layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, True, False, remat_list))
-            # if n.target in paging_layers:
-            #     model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
-            #     layer = getattr(model_indexer[0], model_indexer[1])
-            #     layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, False, True, remat_list))
+                layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, remat_list))
         outputs = model(images)
         loss = criterion(outputs, labels)
 
@@ -81,11 +77,7 @@
             if n.target in remat_layers:
                 model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
                 layer = getattr(model_indexer[0], model_indexer[1])
-                layer.register_backward_hook(remat_and_paging.reuse_layer(model_indexer, n, True, False, remat_list))
-            # if n.target in paging_layers:
-            #     model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
-            #     layer = getattr(model_indexer[0], model_indexer[1])
-            #     layer.register_backward_hook(remat_and_paging.reuse_layer(model_indexer, n, layer, False, True, remat_list))
+                layer.register_backward_hook(remat_and_paging.reuse_layer(model_indexer, n, remat_list))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
@@ -114,10 +106,6 @@
                 model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
                 layer = getattr(model_indexer[0], model_indexer[1])
-                layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, True, False, remat_list))
-            # if n.target in paging_layers:
-            #     model_indexer = remat_and_paging.get_all_parent_layers(model, n.target)[0]
-            #     layer = getattr(model_indexer[0], model_indexer[1])
-            #     layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, False, True, remat_list))
+                layer.register_forward_hook(remat_and_paging.memory_saving(model_indexer, n, remat_list))
         outputs = model(images)
         _, predicted = torch.max(outputs.data, 1)
         total += labels.size(0)
         correct += (predicted == labels).sum().item()
diff --git a/poet/utils/checkmate/core/utils/graph.py b/poet/utils/checkmate/core/utils/graph.py
index 539baea..0a3e084 100644
--- a/poet/utils/checkmate/core/utils/graph.py
+++ b/poet/utils/checkmate/core/utils/graph.py
@@ -10,7 +10,7 @@
 def edge_to_adj_list(E: EdgeList, convert_undirected=False):
     """Returns an (undirected / bidirectional) adjacency list"""
     adj_list = defaultdict(set)
-    for (i, j) in E:
+    for i, j in E:
         adj_list[i].add(j)
         if convert_undirected:
             adj_list[j].add(i)
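The remaining hunks are purely stylistic: tuples are unpacked directly in the for-loop headers (for u, v in ...) instead of through parenthesized patterns, with no behavioral change. For instance, edge_to_adj_list still builds the same adjacency structure (a quick illustrative check, not part of the patch):

    from collections import defaultdict

    E = [(0, 1), (1, 2)]
    adj_list = defaultdict(set)
    for i, j in E:
        adj_list[i].add(j)
        adj_list[j].add(i)  # what convert_undirected=True adds
    # adj_list == {0: {1}, 1: {0, 2}, 2: {1}}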