Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transforming Models from PyTorch/Tensorflow to POET #9

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions poet/architectures/graph_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch
from poet.power_computation import (
LinearLayer,
ReLULayer,
Conv2dLayer,
FlattenLayer,
TanHLayer,
SigmoidLayer,
SkipAddLayer,
DropoutLayer,
GradientLayer,
InputLayer,
SkipAddLayer,
CrossEntropyLoss,
GradientLayer,
BatchNorm2d,
MaxPool2d,
AvgPool2d,
GlobalAvgPool,
get_net_costs,
)


# Transforms the input model's FX graph into a graph whose layer nodes are POET layer nodes.
def _poet_layer_for(target_name):
    """Return the POET layer class whose keyword occurs in ``target_name``, or ``None``.

    Keywords are tested in a fixed priority order, so a target that contains
    several keywords resolves to the first one listed here.
    """
    if "fc" in target_name:
        return LinearLayer
    if "flatten" in target_name:
        return FlattenLayer
    if "relu" in target_name:
        return ReLULayer
    if "conv" in target_name:
        return Conv2dLayer
    if "bn" in target_name:
        return BatchNorm2d
    if "maxpool" in target_name:
        return MaxPool2d
    return None


def graph_transform(traced: torch.fx.graph_module.GraphModule) -> torch.fx.graph_module.GraphModule:
    """Replace recognised layer nodes of ``traced`` with ``call_function`` nodes targeting POET layers.

    A node is matched by looking for a keyword ("fc", "flatten", "relu", "conv",
    "bn", "maxpool") inside ``str(node.target)``. Each matched node is replaced
    by a node that calls the corresponding POET layer class with the original
    node's ``args`` and ``kwargs``.

    Nodes that are left untouched without asking:
      * graph inputs and the graph output (``placeholder`` / ``output`` ops),
      * built-in functions (for example ``operator.add``),
      * nodes whose target already mentions "poet" (inserted by a previous run),
      * any node whose target is literally "x" (kept from the original behaviour).

    For every other unmatched node the user is asked on stdin whether to
    proceed; any answer other than "y"/"Y" terminates the process.

    Args:
        traced: Symbolically traced module; its graph is modified in place.

    Returns:
        The same ``traced`` module, recompiled after the rewrite.

    Raises:
        SystemExit: With exit code 0 when the user declines to proceed.
    """
    # Iterate over a snapshot: nodes are inserted and erased while we walk the graph.
    for node in list(traced.graph.nodes):
        target_name = str(node.target)

        # Inputs and the output are bookkeeping nodes, not layers; never prompt for them.
        if node.op in ("placeholder", "output") or target_name == "x":
            continue
        if "<built-in function" in target_name or "poet" in target_name:
            continue

        poet_layer = _poet_layer_for(target_name)
        if poet_layer is None:
            user_input = input(target_name + " is not supported by POET layers. Would you like to proceed? (y/n)")
            if user_input.lower() != "y":
                # Same effect as the site-provided exit(0), without depending on the site module.
                raise SystemExit(0)
            continue

        with traced.graph.inserting_after(node):
            new_node = traced.graph.call_function(poet_layer, node.args, node.kwargs)
        node.replace_all_uses_with(new_node)
        traced.graph.erase_node(node)

    traced.recompile()
    return traced
72 changes: 72 additions & 0 deletions poet/architectures/network_transform_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from poet.power_computation import (
LinearLayer,
ReLULayer,
Conv2dLayer,
FlattenLayer,
TanHLayer,
SigmoidLayer,
SkipAddLayer,
DropoutLayer,
GradientLayer,
InputLayer,
SkipAddLayer,
CrossEntropyLoss,
GradientLayer,
BatchNorm2d,
MaxPool2d,
AvgPool2d,
GlobalAvgPool,
get_net_costs,
)
import torch.nn as nn
import torchvision.models
from torchvision.models.resnet import BasicBlock, Bottleneck
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We say network transform, but this is mostly for Resnet+common layers? I think we should have a better way to organize this. I don't know exactly how - open to suggestions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the goal was to extend this to other models such as BERT, which my implementation does not fully support yet. I can comment out this import for now.



# Transforms a PyTorch model's network layers into a list of POET layers.
def network_transform(net, layers, batch_size, num_classes, input_shape):
    """Append the POET equivalents of ``net``'s modules to ``layers`` and return ``layers``.

    ``net`` may be a torchvision ``ResNet`` or ``VGG`` (its direct children are
    walked) or any iterable of modules. ``nn.Sequential`` containers and ResNet
    ``BasicBlock`` / ``Bottleneck`` blocks are expanded recursively, child by
    child. Modules of any other unrecognised type are skipped silently, and no
    ``SkipAddLayer`` is created, so residual additions are not represented.

    Args:
        net: Model or iterable of modules to translate.
        layers: Non-empty list of POET layers; its last element is the input of
            the first translated layer. The list is extended in place.
        batch_size: Only forwarded to the recursive calls.
        num_classes: Only forwarded to the recursive calls.
        input_shape: Only forwarded to the recursive calls.

    Returns:
        The ``layers`` list that was passed in.
    """
    if isinstance(net, torchvision.models.resnet.ResNet):
        modules = nn.Sequential(*net.children())
    elif isinstance(net, torchvision.models.vgg.VGG):
        modules = list(net.children())
    else:
        modules = net

    for module in modules:
        # Containers: translate the children starting from the current tail, then splice the
        # new layers in (element 0 of the nested result is that tail itself).
        if isinstance(module, nn.Sequential):
            nested = network_transform(list(module), [layers[-1]], batch_size, num_classes, input_shape)
            layers.extend(nested[1:])
        if isinstance(module, (BasicBlock, Bottleneck)):
            block_children = nn.Sequential(*module.children())
            nested = network_transform(block_children, [layers[-1]], batch_size, num_classes, input_shape)
            layers.extend(nested[1:])

        # Leaf modules: each one consumes the current tail of ``layers``.
        if isinstance(module, nn.Linear):
            linear = LinearLayer(module.in_features, module.out_features, layers[-1])
            # Every linear layer is followed by a ReLU layer.
            layers.extend([linear, ReLULayer(linear)])
        if isinstance(module, nn.ReLU):
            layers.append(ReLULayer(layers[-1]))
        if isinstance(module, nn.Conv2d):
            layers.append(
                Conv2dLayer(
                    module.in_channels,
                    module.out_channels,
                    module.kernel_size,
                    module.stride[0],
                    module.padding,
                    layers[-1],
                )
            )
        if isinstance(module, nn.BatchNorm2d):
            layers.append(BatchNorm2d(layers[-1]))
        if isinstance(module, nn.MaxPool2d):
            square_kernel = (module.kernel_size, module.kernel_size)
            layers.append(MaxPool2d(square_kernel, module.stride, layers[-1]))
        if isinstance(module, nn.AvgPool2d):
            layers.append(AvgPool2d(module.kernel_size, module.stride, layers[-1]))
        if isinstance(module, nn.Tanh):
            layers.append(TanHLayer(layers[-1]))
        if isinstance(module, nn.Sigmoid):
            layers.append(SigmoidLayer(layers[-1]))
        if isinstance(module, nn.Flatten):
            layers.append(FlattenLayer(layers[-1]))
        if isinstance(module, nn.Dropout):
            layers.append(DropoutLayer(layers[-1]))
    return layers
62 changes: 62 additions & 0 deletions poet/architectures/network_transform_tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from poet.power_computation import (
LinearLayer,
ReLULayer,
Conv2dLayer,
FlattenLayer,
TanHLayer,
SigmoidLayer,
SkipAddLayer,
DropoutLayer,
GradientLayer,
InputLayer,
SkipAddLayer,
CrossEntropyLoss,
GradientLayer,
BatchNorm2d,
MaxPool2d,
AvgPool2d,
GlobalAvgPool,
get_net_costs,
)
from poet.power_computation_transformer import QueryKeyValueMatrix, QKTMatrix, QKTVMatrix
from torchvision.models.resnet import BasicBlock, Bottleneck
import tensorflow as tf


# Transforms a TensorFlow/Keras model's layers into a list of POET layers.
def _keras_activation_name(module):
    """Return the name of the function wrapped by a Keras ``Activation`` layer (e.g. ``"relu"``).

    Falls back to the layer's own name when the wrapped activation has no ``__name__``.
    """
    activation = getattr(module, "activation", None)
    return getattr(activation, "__name__", None) or module.name


def network_transform(net, layers, batch_size, num_classes, input_shape):
    """Append the POET equivalents of the Keras layers in ``net`` to ``layers`` and return ``layers``.

    Keras layers of an unrecognised type are skipped silently.

    Args:
        net: Iterable of Keras layers (for example ``model.layers``).
        layers: Non-empty list of POET layers; its last element is the input of
            the first translated layer. The list is extended in place.
        batch_size: Unused by this function.
        num_classes: Unused by this function.
        input_shape: Unused by this function.

    Returns:
        The ``layers`` list that was passed in.

    Raises:
        ValueError: If a ``Conv2D`` layer uses a padding mode other than "valid" or "same".
    """
    for module in net:
        if isinstance(module, tf.keras.layers.Dense):
            # NOTE(review): the input feature count is set to ``units`` as well — confirm
            # whether the real input width should be used here.
            lin_layer = LinearLayer(module.units, module.units, layers[-1])
            # Every dense layer is followed by a ReLU layer.
            layers.extend([lin_layer, ReLULayer(lin_layer)])
        if isinstance(module, tf.keras.layers.Activation):
            # Decide by the wrapped activation function, not by the layer's name: layers created
            # without an explicit name get automatic names and would otherwise never match.
            activation_name = _keras_activation_name(module)
            if activation_name == "relu":
                layers.append(ReLULayer(layers[-1]))
            elif activation_name == "tanh":
                layers.append(TanHLayer(layers[-1]))
            elif activation_name == "sigmoid":
                layers.append(SigmoidLayer(layers[-1]))
        if isinstance(module, tf.keras.layers.Conv2D):
            if module.padding == "valid":
                padding = (0, 0)
            elif module.padding == "same":
                # "same" padding depends on the kernel: half the kernel size per side for odd kernels.
                padding = tuple(size // 2 for size in module.kernel_size)
            else:
                raise ValueError("Unsupported Conv2D padding mode: " + repr(module.padding))
            # NOTE(review): the input channel count is hard-coded to 1 — confirm against the cost model.
            layers.append(Conv2dLayer(1, module.filters, module.kernel_size, module.strides[0], padding, layers[-1]))
        if isinstance(module, tf.keras.layers.BatchNormalization):
            layers.append(BatchNorm2d(layers[-1]))
        if isinstance(module, tf.keras.layers.MaxPool2D):
            layers.append(MaxPool2d(module.pool_size, module.strides[0], layers[-1]))
        if isinstance(module, tf.keras.layers.GlobalAveragePooling2D):
            # Only pooling layers that keep the spatial dimensions are translated.
            if module.keepdims:
                layers.append(GlobalAvgPool(layers[-1]))
        if isinstance(module, tf.keras.layers.Flatten):
            layers.append(FlattenLayer(layers[-1]))
        if isinstance(module, tf.keras.layers.Dropout):
            layers.append(DropoutLayer(layers[-1]))
    return layers
19 changes: 19 additions & 0 deletions poet/framework_integration/pytorch/test_resnet_graph_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from torch.fx import symbolic_trace
import torchvision
from poet.architectures.graph_transformation import graph_transform

# Transforms ResNet model graphs into graphs of POET layer nodes and prints every node.

# ResNet18 and ResNet50 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
for build_resnet in (torchvision.models.resnet18, torchvision.models.resnet50):
    traced = symbolic_trace(build_resnet(pretrained=True))
    poet_traced = graph_transform(traced)
    for n in poet_traced.graph.nodes:
        print(n.target)
        print(n.name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from poet.power_computation import InputLayer
import torchvision.models
from poet.architectures.network_transform_pytorch import network_transform

# Transforms ResNet model network layers into POET computation layers.

# NOTE(review): batch_size is a 1-tuple, so the InputLayer shape below is ((1,), 3, 32, 32) —
# confirm that InputLayer expects this rather than a plain integer batch size.
batch_size = (1,)
input_shape = (3, 32, 32)
num_classes = 10

# network_transform extends the list it receives in place, so each model gets its own fresh
# input layer; reusing one list would stack ResNet50's layers on top of ResNet18's.

# Resnet18 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
layers = [InputLayer((batch_size, *input_shape))]
final_layers = network_transform(torchvision.models.resnet18(pretrained=True), layers, batch_size, num_classes, input_shape)
print(final_layers)

# Resnet50 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
layers = [InputLayer((batch_size, *input_shape))]
final_layers = network_transform(torchvision.models.resnet50(pretrained=True), layers, batch_size, num_classes, input_shape)
print(final_layers)
16 changes: 16 additions & 0 deletions poet/framework_integration/pytorch/test_vgg_network_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from poet.power_computation import InputLayer
import torch
from poet.architectures.network_transform_pytorch import network_transform

# Transforms VGG model network layers into POET computation layers.

input_shape = (3, 32, 32)
batch_size = (1,)
num_classes = 10
layers = [InputLayer((batch_size, *input_shape))]

# VGG16 model transformation - https://download.pytorch.org/models/vgg16-397923af.pth
vgg16 = torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True)
final_layers = network_transform(vgg16, layers, batch_size, num_classes, input_shape)
print(final_layers)
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from poet.architectures.network_transform_tensorflow import network_transform
from poet.power_computation import InputLayer
import tensorflow as tf

# transforms ResNet Model network layers into POET computation layers


# defining ResNet model in TensorFlow
def BasicBlock(inputs, num_channels, kernel_size, num_blocks, skip_blocks, name):
    """Chain ``num_blocks`` residual blocks, leaving out the indices listed in ``skip_blocks``.

    Each block adds the output of ``ConvNormRelu`` to its own input and applies
    a ReLU. Blocks are named ``<name>.<index>``. When every index is skipped
    (or ``num_blocks`` is 0) ``inputs`` is returned unchanged.
    """
    x = inputs
    for block_index in range(num_blocks):
        if block_index in skip_blocks:
            continue
        block_name = name + "." + str(block_index)
        branch = ConvNormRelu(x, num_channels, kernel_size, strides=[1, 1], name=block_name)
        x = tf.keras.layers.Add()([x, branch])
        x = tf.keras.layers.Activation("relu")(x)
    return x


def BasicBlockDown(inputs, num_channels, kernel_size, name):
    """Residual block with strided downsampling.

    The main branch is ``ConvNormRelu`` with strides ``[2, 1]``; the shortcut is
    a stride-2 1x1 convolution followed by batch normalization. Both are added
    and passed through a ReLU. Layers are named under ``<name>.0``.
    """
    # The main branch is built first so that automatically named layers keep their creation order.
    main_branch = ConvNormRelu(inputs, num_channels, kernel_size, strides=[2, 1], name=name + ".0")

    downsample_conv = tf.keras.layers.Conv2D(
        num_channels,
        kernel_size=1,
        strides=2,
        padding="same",
        activation="linear",
        use_bias=False,
        name=name + ".0.downsample.0",
    )
    shortcut = downsample_conv(inputs)
    downsample_norm = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".0.downsample.1")
    shortcut = downsample_norm(shortcut)

    merged = tf.keras.layers.Add()([shortcut, main_branch])
    return tf.keras.layers.Activation("relu")(merged)


def ConvNormRelu(x, num_channels, kernel_size, strides, name):
    """Apply conv -> batch norm -> ReLU -> conv -> batch norm to ``x``.

    ``strides[0]`` is the stride of the first convolution and ``strides[1]`` of
    the second. A first stride of 2 is implemented as explicit zero padding of 1
    followed by a "valid" convolution; any other first stride uses "same" padding.
    Layers are named ``<name>.pad``, ``<name>.conv1``, ``<name>.bn1``,
    ``<name>.conv2`` and ``<name>.bn2``.
    """
    first_stride, second_stride = strides[0], strides[1]

    if first_stride == 2:
        x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name=name + ".pad")(x)
        first_padding = "valid"
    else:
        first_padding = "same"

    x = tf.keras.layers.Conv2D(
        num_channels, kernel_size, first_stride, padding=first_padding, activation="linear", use_bias=False, name=name + ".conv1"
    )(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".bn1")(x)
    x = tf.keras.layers.Activation("relu")(x)

    x = tf.keras.layers.Conv2D(
        num_channels, kernel_size, second_stride, padding="same", activation="linear", use_bias=False, name=name + ".conv2"
    )(x)
    return tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".bn2")(x)


def ResNet18(inputs):
    """Build the ResNet18 forward graph on ``inputs`` and return the output of the final dense layer.

    Structure: a padded 7x7 stride-2 stem (conv, batch norm, ReLU, padded 3x3
    max pool), stage ``layer1`` with two 64-channel blocks, then stages
    ``layer2``..``layer4`` (128, 256, 512 channels), each made of one
    downsampling block followed by one regular block, and finally global
    average pooling and a 1000-unit dense layer.
    """
    # Stem
    x = tf.keras.layers.ZeroPadding2D(padding=(3, 3), name="pad")(inputs)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=7, strides=2, padding="valid", activation="linear", use_bias=False, name="conv1")(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name="bn1")(x)
    x = tf.keras.layers.Activation("relu", name="relu")(x)
    x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name="pad1")(x)
    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="maxpool")(x)

    # Residual stages
    x = BasicBlock(x, num_channels=64, kernel_size=3, num_blocks=2, skip_blocks=[], name="layer1")
    for stage_name, stage_channels in (("layer2", 128), ("layer3", 256), ("layer4", 512)):
        x = BasicBlockDown(x, num_channels=stage_channels, kernel_size=3, name=stage_name)
        # Block 0 of the stage is the downsampling block above, so it is skipped here.
        x = BasicBlock(x, num_channels=stage_channels, kernel_size=3, num_blocks=2, skip_blocks=[0], name=stage_name)

    # Classifier head
    x = tf.keras.layers.GlobalAveragePooling2D(name="avgpool")(x)
    return tf.keras.layers.Dense(units=1000, use_bias=True, activation="linear", name="fc")(x)


# Build the Keras ResNet18 model on an input with unspecified spatial size and 3 channels.
inputs = tf.keras.Input((None, None, 3))
resnet_tf = ResNet18(inputs)
model = tf.keras.Model(inputs, resnet_tf)

# Settings for the POET side of the transformation.
num_classes = 10
input_shape = (3, 32, 32)
batch_size = (1,)
layers = [InputLayer((batch_size, *input_shape))]

# Translate a copy of the model's layer list into POET layers and show the result.
final_layers = network_transform(list(model.layers), layers, batch_size, num_classes, input_shape)
print(final_layers)
2 changes: 1 addition & 1 deletion poet/poet_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _initialize_variables(self):

def _create_correctness_constraints(self):
# ensure all computations are possible
for (u, v) in self.g.edge_list:
for u, v in self.g.edge_list:
for t in range(self.T):
self.m += self.R[t][v] <= self.R[t][u] + self.SRam[t][u]
# ensure all checkpoints are in memory
Expand Down
3 changes: 2 additions & 1 deletion poet/poet_solver_gurobi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

# noinspection PyPackageRequirements


# POET ILP defined using Gurobi
class POETSolverGurobi:
def __init__(
Expand Down Expand Up @@ -124,7 +125,7 @@ def _disable_paging(self):

def _create_correctness_constraints(self):
# ensure all computations are possible
for (u, v) in self.g.edge_list:
for u, v in self.g.edge_list:
for t in range(self.T):
self.m.addLConstr(self.R[t, v], GRB.LESS_EQUAL, self.R[t, u] + self.SRam[t, u])
# ensure all checkpoints are in memory
Expand Down
2 changes: 1 addition & 1 deletion poet/utils/checkmate/core/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def edge_to_adj_list(E: EdgeList, convert_undirected=False):
"""Returns an (undirected / bidirectional) adjacency list"""
adj_list = defaultdict(set)
for (i, j) in E:
for i, j in E:
adj_list[i].add(j)
if convert_undirected:
adj_list[j].add(i)
Expand Down
21 changes: 21 additions & 0 deletions test/test_pytorch_resnet_graph_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from torch.fx import symbolic_trace
import torchvision
from poet.architectures.graph_transformation import graph_transform

# Transforms ResNet model graphs into POET layer nodes and checks that each transformed
# graph still ends in its output node.

# ResNet18 and ResNet50 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
for build_resnet in (torchvision.models.resnet18, torchvision.models.resnet50):
    traced = symbolic_trace(build_resnet(pretrained=True))
    poet_traced = graph_transform(traced)
    nodes = list(poet_traced.graph.nodes)
    for n in nodes:
        print(n.target)
        print(n.name)
    # Check the last node explicitly, for every model, instead of relying on the loop
    # variable left over after printing.
    assert nodes, "transformed graph has no nodes"
    assert nodes[-1].target == "output" and nodes[-1].name == "output"
Loading