Transforming Models from PyTorch/TensorFlow to POET #9
base: main
Commits: 853e088, 06349a8, 305f861, ce1c44d, 6dbcad4
poet/architectures/graph_transformation.py
@@ -0,0 +1,59 @@
from poet.power_computation import (
    LinearLayer,
    ReLULayer,
    Conv2dLayer,
    FlattenLayer,
    TanHLayer,
    SigmoidLayer,
    SkipAddLayer,
    DropoutLayer,
    GradientLayer,
    InputLayer,
    CrossEntropyLoss,
    BatchNorm2d,
    MaxPool2d,
    AvgPool2d,
    GlobalAvgPool,
    get_net_costs,
)


# transforms the input model's traced graph into a graph with POET layer nodes
def graph_transform(traced):
    for n in traced.graph.nodes:
        # built-in function targets (e.g. operator.add) have no POET equivalent yet
        if "<built-in function" in str(n.target):
Review comment: agreed, would adding a print statement informing that the layer is not supported by POET, and continuing to loop through, suffice?
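A minimal sketch of the suggested fallback (hypothetical, not part of this diff):

    # warn about the unsupported layer instead of silently skipping it
    print(f"{n.target} is not supported by POET; skipping")
    continue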
            continue
        elif "fc" in str(n.target):
            with traced.graph.inserting_after(n):
                new_node = traced.graph.call_function(LinearLayer, n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            traced.graph.erase_node(n)
        elif "flatten" in str(n.target):
            with traced.graph.inserting_after(n):
                new_node = traced.graph.call_function(FlattenLayer, n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            traced.graph.erase_node(n)
        elif "relu" in str(n.target):
            with traced.graph.inserting_after(n):
                new_node = traced.graph.call_function(ReLULayer, n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            traced.graph.erase_node(n)
        elif "conv" in str(n.target):
            with traced.graph.inserting_after(n):
                new_node = traced.graph.call_function(Conv2dLayer, n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            traced.graph.erase_node(n)
        elif "bn" in str(n.target):
            with traced.graph.inserting_after(n):
                new_node = traced.graph.call_function(BatchNorm2d, n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            traced.graph.erase_node(n)
        elif "maxpool" in str(n.target):
            with traced.graph.inserting_after(n):
                new_node = traced.graph.call_function(MaxPool2d, n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            traced.graph.erase_node(n)
    traced.recompile()
    return traced
poet/architectures/network_transform_pytorch.py
@@ -0,0 +1,73 @@
from poet.power_computation import (
    LinearLayer,
    ReLULayer,
    Conv2dLayer,
    FlattenLayer,
    TanHLayer,
    SigmoidLayer,
    SkipAddLayer,
    DropoutLayer,
    GradientLayer,
    InputLayer,
    CrossEntropyLoss,
    BatchNorm2d,
    MaxPool2d,
    AvgPool2d,
    GlobalAvgPool,
    get_net_costs,
)
from poet.power_computation_transformer import QueryKeyValueMatrix, QKTMatrix, QKTVMatrix
import torch.nn as nn
import torchvision.models
from torchvision.models.resnet import BasicBlock, Bottleneck
Review comment: We say network transform, but this is mostly for ResNet plus common layers? I think we should have a better way to organize this. I don't know exactly how; open to suggestions.
Review reply: I think the goal was to extend this for other models like BERT, which I didn't fully support in my implementation yet. I can comment out this import for now.
# transforms the input model's network layers into POET layers (PyTorch models)
def network_transform(net, layers, batch_size, num_classes, input_shape):
    if isinstance(net, torchvision.models.resnet.ResNet):
        modules = nn.Sequential(*list(net.children()))
    elif isinstance(net, torchvision.models.vgg.VGG):
        modules = list(net.children())
    else:
        modules = net
    for module in modules:
        if isinstance(module, nn.Sequential):
            sequential_modules = [child for child in module]
            input = network_transform(sequential_modules, [layers[-1]], batch_size, num_classes, input_shape)
            layers.extend(input[1:])
        if isinstance(module, BasicBlock) or isinstance(module, Bottleneck):
            input = network_transform(nn.Sequential(*list(module.children())), [layers[-1]], batch_size, num_classes, input_shape)
            layers.extend(input[1:])
        if isinstance(module, nn.Linear):
            lin_layer = LinearLayer(module.in_features, module.out_features, layers[-1])
            act_layer = ReLULayer(lin_layer)
            layers.extend([lin_layer, act_layer])
        if isinstance(module, nn.ReLU):
            relu_layer = ReLULayer(layers[-1])
            layers.append(relu_layer)
        if isinstance(module, nn.Conv2d):
            conv_layer = Conv2dLayer(
                module.in_channels, module.out_channels, module.kernel_size, module.stride[0], module.padding, layers[-1]
            )
            layers.append(conv_layer)
        if isinstance(module, nn.BatchNorm2d):
            layers.append(BatchNorm2d(layers[-1]))
        if isinstance(module, nn.MaxPool2d):
            layers.append(MaxPool2d((module.kernel_size, module.kernel_size), module.stride, layers[-1]))
        if isinstance(module, nn.AvgPool2d):
            layers.append(AvgPool2d(module.kernel_size, module.stride, layers[-1]))
        if isinstance(module, nn.Tanh):
            tanh_layer = TanHLayer(layers[-1])
            layers.append(tanh_layer)
        if isinstance(module, nn.Sigmoid):
            sigmoid_layer = SigmoidLayer(layers[-1])
            layers.append(sigmoid_layer)
        if isinstance(module, nn.Flatten):
            flatten_layer = FlattenLayer(layers[-1])
            layers.append(flatten_layer)
        if isinstance(module, nn.Dropout):
            dropout_layer = DropoutLayer(layers[-1])
            layers.append(dropout_layer)
    return layers
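One possible reorganization along the lines the review above asks about, as a sketch only: a dispatch table mapping nn module types to POET layer builders, so supporting a new layer means adding one table entry rather than another if-branch. The table and helper names are hypothetical, not part of this diff, and assume the imports at the top of this file:

    # hypothetical dispatch table: nn module type -> POET layer builder
    LAYER_BUILDERS = {
        nn.ReLU: lambda m, prev: ReLULayer(prev),
        nn.Tanh: lambda m, prev: TanHLayer(prev),
        nn.Sigmoid: lambda m, prev: SigmoidLayer(prev),
        nn.Flatten: lambda m, prev: FlattenLayer(prev),
        nn.Dropout: lambda m, prev: DropoutLayer(prev),
        nn.BatchNorm2d: lambda m, prev: BatchNorm2d(prev),
        nn.Linear: lambda m, prev: LinearLayer(m.in_features, m.out_features, prev),
    }

    def build_layer(module, prev):
        for module_type, builder in LAYER_BUILDERS.items():
            if isinstance(module, module_type):
                return builder(module, prev)
        return None  # unsupported module; caller decides how to handle it

Container types (Sequential, BasicBlock, Bottleneck) would still need the recursive handling above; the table only covers leaf layers.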
poet/architectures/network_transform_tensorflow.py
@@ -0,0 +1,62 @@
from poet.power_computation import (
    LinearLayer,
    ReLULayer,
    Conv2dLayer,
    FlattenLayer,
    TanHLayer,
    SigmoidLayer,
    SkipAddLayer,
    DropoutLayer,
    GradientLayer,
    InputLayer,
    CrossEntropyLoss,
    BatchNorm2d,
    MaxPool2d,
    AvgPool2d,
    GlobalAvgPool,
    get_net_costs,
)
from poet.power_computation_transformer import QueryKeyValueMatrix, QKTMatrix, QKTVMatrix
# from torchvision.models.resnet import BasicBlock, Bottleneck  # unused in the TensorFlow transform
import tensorflow as tf


# transforms the input model's network layers into POET layers (TensorFlow models)
def network_transform(net, layers, batch_size, num_classes, input_shape):
    for module in net:
        if isinstance(module, tf.keras.layers.Dense):
            lin_layer = LinearLayer(module.units, module.units, layers[-1])
            act_layer = ReLULayer(lin_layer)
            layers.extend([lin_layer, act_layer])
        if isinstance(module, tf.keras.layers.Activation) and module._name == "relu":
            relu_layer = ReLULayer(layers[-1])
            layers.append(relu_layer)
        if isinstance(module, tf.keras.layers.Conv2D):
            # Keras only exposes "valid"/"same"; map to explicit padding tuples
            if module.padding == "valid":
                padding = (0, 0)
            elif module.padding == "same":
                padding = (1, 1)
            conv_layer = Conv2dLayer(1, module.filters, module.kernel_size, module.strides[0], padding, layers[-1])
            layers.append(conv_layer)
        if isinstance(module, tf.keras.layers.BatchNormalization):
            layers.append(BatchNorm2d(layers[-1]))
        if isinstance(module, tf.keras.layers.MaxPool2D):
            layers.append(MaxPool2d(module.pool_size, module.strides[0], layers[-1]))
        if isinstance(module, tf.keras.layers.GlobalAveragePooling2D):
            if module.keepdims:
                layers.append(GlobalAvgPool(layers[-1]))
        if isinstance(module, tf.keras.layers.Activation) and module._name == "tanh":
            tanh_layer = TanHLayer(layers[-1])
            layers.append(tanh_layer)
        if isinstance(module, tf.keras.layers.Activation) and module._name == "sigmoid":
            sigmoid_layer = SigmoidLayer(layers[-1])
            layers.append(sigmoid_layer)
        if isinstance(module, tf.keras.layers.Flatten):
            flatten_layer = FlattenLayer(layers[-1])
            layers.append(flatten_layer)
        if isinstance(module, tf.keras.layers.Dropout):
            dropout_layer = DropoutLayer(layers[-1])
            layers.append(dropout_layer)
    return layers
@@ -0,0 +1,19 @@
from torch.fx import symbolic_trace
import torchvision
from poet.architectures.graph_transformation import graph_transform

# transforms ResNet model graphs into POET layer nodes

# # ResNet18 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
Review comment: Is this commented for a reason? Good to have the tests in.
# traced = symbolic_trace(torchvision.models.resnet18(pretrained=True))
# poet_traced = graph_transform(traced)
# for n in poet_traced.graph.nodes:
#     print(n.target)
#     print(n.name)

# # ResNet50 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
# traced = symbolic_trace(torchvision.models.resnet50(pretrained=True))
# poet_traced = graph_transform(traced)
# for n in poet_traced.graph.nodes:
#     print(n.target)
#     print(n.name)
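Following the review note above about keeping the tests in, a hedged sketch of how the commented-out check could run as a test function (the test name is hypothetical, and this assumes the pretrained-weight download is acceptable wherever tests run; reuses the imports at the top of this file):

    def test_resnet18_graph_transform():
        traced = symbolic_trace(torchvision.models.resnet18(pretrained=True))
        poet_traced = graph_transform(traced)
        # the transformed graph should recompile and still contain nodes
        assert len(list(poet_traced.graph.nodes)) > 0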
@@ -0,0 +1,20 @@
from poet.power_computation import InputLayer
import torchvision.models
from poet.architectures.network_transform_pytorch import network_transform

# transforms ResNet model network layers into POET computation layers

batch_size = (1,)
input_shape = (3, 32, 32)
num_classes = 10
layers = [InputLayer((batch_size, *input_shape))]

# uncomment to output transformations for different ResNet models
Review comment: Slightly confused by the comment message.
# # ResNet18 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
# final_layers = network_transform(torchvision.models.resnet18(pretrained=True), layers, batch_size, num_classes, input_shape)
# print(final_layers)

# # ResNet50 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917
# final_layers = network_transform(torchvision.models.resnet50(pretrained=True), layers, batch_size, num_classes, input_shape)
# print(final_layers)
@@ -0,0 +1,16 @@
from poet.power_computation import InputLayer
import torch
from poet.architectures.network_transform_pytorch import network_transform

# transforms VGG model network layers into POET computation layers

batch_size = (1,)
input_shape = (3, 32, 32)
num_classes = 10
layers = [InputLayer((batch_size, *input_shape))]

# VGG16 model transformation - https://download.pytorch.org/models/vgg16-397923af.pth
final_layers = network_transform(
    torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True), layers, batch_size, num_classes, input_shape
)
print(final_layers)
@@ -0,0 +1,81 @@
from poet.architectures.network_transform_tensorflow import network_transform
from poet.power_computation import InputLayer
import tensorflow as tf

# transforms ResNet model network layers into POET computation layers
Review comment: The comment reads VGG Model, but I think the ResNet model is being constructed here?
# defining a ResNet model in TensorFlow
def BasicBlock(inputs, num_channels, kernel_size, num_blocks, skip_blocks, name):
    """Basic residual block"""
    x = inputs
    for i in range(num_blocks):
        if i not in skip_blocks:
            x1 = ConvNormRelu(x, num_channels, kernel_size, strides=[1, 1], name=name + "." + str(i))
            x = tf.keras.layers.Add()([x, x1])
            x = tf.keras.layers.Activation("relu")(x)
    return x


def BasicBlockDown(inputs, num_channels, kernel_size, name):
    """Residual block with strided downsampling"""
    x = inputs
    x1 = ConvNormRelu(x, num_channels, kernel_size, strides=[2, 1], name=name + ".0")
    x = tf.keras.layers.Conv2D(
        num_channels, kernel_size=1, strides=2, padding="same", activation="linear", use_bias=False, name=name + ".0.downsample.0"
    )(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".0.downsample.1")(x)
    x = tf.keras.layers.Add()([x, x1])
    x = tf.keras.layers.Activation("relu")(x)
    return x


def ConvNormRelu(x, num_channels, kernel_size, strides, name):
    """Two consecutive conv + batch-norm blocks, with a ReLU after the first"""
    if strides[0] == 2:
        x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name=name + ".pad")(x)
        x = tf.keras.layers.Conv2D(
            num_channels, kernel_size, strides[0], padding="valid", activation="linear", use_bias=False, name=name + ".conv1"
        )(x)
    else:
        x = tf.keras.layers.Conv2D(
            num_channels, kernel_size, strides[0], padding="same", activation="linear", use_bias=False, name=name + ".conv1"
        )(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".bn1")(x)
    x = tf.keras.layers.Activation("relu")(x)
    x = tf.keras.layers.Conv2D(
        num_channels, kernel_size, strides[1], padding="same", activation="linear", use_bias=False, name=name + ".conv2"
    )(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".bn2")(x)
    return x


def ResNet18(inputs):
    x = tf.keras.layers.ZeroPadding2D(padding=(3, 3), name="pad")(inputs)
    x = tf.keras.layers.Conv2D(filters=64, kernel_size=7, strides=2, padding="valid", activation="linear", use_bias=False, name="conv1")(x)
    x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name="bn1")(x)
    x = tf.keras.layers.Activation("relu", name="relu")(x)
    x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name="pad1")(x)
    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="maxpool")(x)
    x = BasicBlock(x, num_channels=64, kernel_size=3, num_blocks=2, skip_blocks=[], name="layer1")
    x = BasicBlockDown(x, num_channels=128, kernel_size=3, name="layer2")
    x = BasicBlock(x, num_channels=128, kernel_size=3, num_blocks=2, skip_blocks=[0], name="layer2")
    x = BasicBlockDown(x, num_channels=256, kernel_size=3, name="layer3")
    x = BasicBlock(x, num_channels=256, kernel_size=3, num_blocks=2, skip_blocks=[0], name="layer3")
    x = BasicBlockDown(x, num_channels=512, kernel_size=3, name="layer4")
    x = BasicBlock(x, num_channels=512, kernel_size=3, num_blocks=2, skip_blocks=[0], name="layer4")
    x = tf.keras.layers.GlobalAveragePooling2D(name="avgpool")(x)
    x = tf.keras.layers.Dense(units=1000, use_bias=True, activation="linear", name="fc")(x)
    return x


inputs = tf.keras.Input((None, None, 3))
resnet_tf = ResNet18(inputs)
model = tf.keras.Model(inputs, resnet_tf)
batch_size = (1,)
input_shape = (3, 32, 32)
num_classes = 10
layers = [InputLayer((batch_size, *input_shape))]

final_layers = network_transform([layer for layer in model.layers], layers, batch_size, num_classes, input_shape)
print(final_layers)
Review comment: To improve readability, can we use type hints? E.g. a torch.nn type as the input and poet.DNNLayer as the expected output?
Review reply: The function actually takes in the graph of a torch model and outputs a recompiled version of the graph with poet.DNNLayers. I printed the type of this graph, which is torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl, so do you still recommend including this as a type hint, or would that stray away from readability?
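A minimal sketch of how the annotation could read, given the discussion above: torch.fx.GraphModule is the public base class of the dynamically generated GraphModuleImpl, so it works as the hint even though it is not the concrete runtime type (sketch only, not part of this diff):

    from torch.fx import GraphModule

    # annotate with the public base class rather than the generated subclass
    def graph_transform(traced: GraphModule) -> GraphModule:
        ...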