-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Transforming Models from PyTorch/Tensorflow to POET #9
Open
arnavsinghvi11
wants to merge
5
commits into
ShishirPatil:main
Choose a base branch
from
arnavsinghvi11:transformation_testing
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
853e088
network transformation function & testing for PyTorch to POET
06349a8
network transformation function & testing for TensorFlow to POET
305f861
graph transformation function for converting PyTorch model graph to P…
ce1c44d
updated typing and graph transformation edge case, revised changes fr…
6dbcad4
renamed to framework_integration, adding testing files to test folder
arnavsinghvi11 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import torch | ||
from poet.power_computation import ( | ||
LinearLayer, | ||
ReLULayer, | ||
Conv2dLayer, | ||
FlattenLayer, | ||
TanHLayer, | ||
SigmoidLayer, | ||
SkipAddLayer, | ||
DropoutLayer, | ||
GradientLayer, | ||
InputLayer, | ||
SkipAddLayer, | ||
CrossEntropyLoss, | ||
GradientLayer, | ||
BatchNorm2d, | ||
MaxPool2d, | ||
AvgPool2d, | ||
GlobalAvgPool, | ||
get_net_costs, | ||
) | ||
|
||
|
||
# transforms input model's graph to output graph with POET layer nodes | ||
def graph_transform(traced: torch.fx.graph_module.GraphModule) -> torch.fx.graph_module.GraphModule: | ||
for n in traced.graph.nodes: | ||
# ignores built-in functions and input x which are not layer nodes in the model graph | ||
# ignores poet layer nodes which are added to the model graph from this function | ||
if "<built-in function" in str(n.target) or "poet" in str(n.target) or "x" == str(n.target): | ||
continue | ||
elif "fc" in str(n.target): | ||
with traced.graph.inserting_after(n): | ||
new_node = traced.graph.call_function(LinearLayer, n.args, n.kwargs) | ||
n.replace_all_uses_with(new_node) | ||
traced.graph.erase_node(n) | ||
elif "flatten" in str(n.target): | ||
with traced.graph.inserting_after(n): | ||
new_node = traced.graph.call_function(FlattenLayer, n.args, n.kwargs) | ||
n.replace_all_uses_with(new_node) | ||
traced.graph.erase_node(n) | ||
elif "relu" in str(n.target): | ||
with traced.graph.inserting_after(n): | ||
new_node = traced.graph.call_function(ReLULayer, n.args, n.kwargs) | ||
n.replace_all_uses_with(new_node) | ||
traced.graph.erase_node(n) | ||
elif "conv" in str(n.target): | ||
with traced.graph.inserting_after(n): | ||
new_node = traced.graph.call_function(Conv2dLayer, n.args, n.kwargs) | ||
n.replace_all_uses_with(new_node) | ||
traced.graph.erase_node(n) | ||
elif "bn" in str(n.target): | ||
with traced.graph.inserting_after(n): | ||
new_node = traced.graph.call_function(BatchNorm2d, n.args, n.kwargs) | ||
n.replace_all_uses_with(new_node) | ||
traced.graph.erase_node(n) | ||
elif "maxpool" in str(n.target): | ||
with traced.graph.inserting_after(n): | ||
new_node = traced.graph.call_function(MaxPool2d, n.args, n.kwargs) | ||
n.replace_all_uses_with(new_node) | ||
traced.graph.erase_node(n) | ||
else: | ||
user_input = input(str(n.target) + " is not supported by POET layers. Would you like to proceed? (y/n)") | ||
if user_input.lower() == "y": | ||
continue | ||
else: | ||
exit(0) | ||
traced.recompile() | ||
return traced |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from poet.power_computation import ( | ||
LinearLayer, | ||
ReLULayer, | ||
Conv2dLayer, | ||
FlattenLayer, | ||
TanHLayer, | ||
SigmoidLayer, | ||
SkipAddLayer, | ||
DropoutLayer, | ||
GradientLayer, | ||
InputLayer, | ||
SkipAddLayer, | ||
CrossEntropyLoss, | ||
GradientLayer, | ||
BatchNorm2d, | ||
MaxPool2d, | ||
AvgPool2d, | ||
GlobalAvgPool, | ||
get_net_costs, | ||
) | ||
import torch.nn as nn | ||
import torchvision.models | ||
from torchvision.models.resnet import BasicBlock, Bottleneck | ||
|
||
|
||
# transforms input model's network layers to output graph with POET layers for PyTorch models | ||
def network_transform(net, layers, batch_size, num_classes, input_shape): | ||
if isinstance(net, torchvision.models.resnet.ResNet): | ||
modules = nn.Sequential(*list(net.children())) | ||
elif isinstance(net, torchvision.models.vgg.VGG): | ||
modules = list(net.children()) | ||
else: | ||
modules = net | ||
for module in modules: | ||
if isinstance(module, nn.Sequential): | ||
sequential_modules = [child for child in module] | ||
input = network_transform(sequential_modules, [layers[-1]], batch_size, num_classes, input_shape) | ||
layers.extend(input[1:]) | ||
if isinstance(module, BasicBlock) or isinstance(module, Bottleneck): | ||
input = network_transform(nn.Sequential(*list(module.children())), [layers[-1]], batch_size, num_classes, input_shape) | ||
layers.extend(input[1:]) | ||
if isinstance(module, nn.Linear): | ||
lin_layer = LinearLayer(module.in_features, module.out_features, layers[-1]) | ||
act_layer = ReLULayer(lin_layer) | ||
layers.extend([lin_layer, act_layer]) | ||
if isinstance(module, nn.ReLU): | ||
relu_layer = ReLULayer(layers[-1]) | ||
layers.append(relu_layer) | ||
if isinstance(module, nn.Conv2d): | ||
conv_layer = Conv2dLayer( | ||
module.in_channels, module.out_channels, module.kernel_size, module.stride[0], module.padding, layers[-1] | ||
) | ||
layers.append(conv_layer) | ||
if isinstance(module, nn.BatchNorm2d): | ||
layers.append(BatchNorm2d(layers[-1])) | ||
if isinstance(module, nn.MaxPool2d): | ||
layers.append(MaxPool2d((module.kernel_size, module.kernel_size), module.stride, layers[-1])) | ||
if isinstance(module, nn.AvgPool2d): | ||
layers.append(AvgPool2d(module.kernel_size, module.stride, layers[-1])) | ||
if isinstance(module, nn.Tanh): | ||
tanh_layer = TanHLayer(layers[-1]) | ||
layers.append(tanh_layer) | ||
if isinstance(module, nn.Sigmoid): | ||
sigmoid_layer = SigmoidLayer(layers[-1]) | ||
layers.append(sigmoid_layer) | ||
if isinstance(module, nn.Flatten): | ||
flatten_layer = FlattenLayer(layers[-1]) | ||
layers.append(flatten_layer) | ||
if isinstance(module, nn.Dropout): | ||
dropout_layer = DropoutLayer(layers[-1]) | ||
layers.append(dropout_layer) | ||
return layers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
from poet.power_computation import ( | ||
LinearLayer, | ||
ReLULayer, | ||
Conv2dLayer, | ||
FlattenLayer, | ||
TanHLayer, | ||
SigmoidLayer, | ||
SkipAddLayer, | ||
DropoutLayer, | ||
GradientLayer, | ||
InputLayer, | ||
SkipAddLayer, | ||
CrossEntropyLoss, | ||
GradientLayer, | ||
BatchNorm2d, | ||
MaxPool2d, | ||
AvgPool2d, | ||
GlobalAvgPool, | ||
get_net_costs, | ||
) | ||
from poet.power_computation_transformer import QueryKeyValueMatrix, QKTMatrix, QKTVMatrix | ||
from torchvision.models.resnet import BasicBlock, Bottleneck | ||
import tensorflow as tf | ||
|
||
|
||
##transforms input model's network layers to output graph with POET layers for TensorFlow models | ||
def network_transform(net, layers, batch_size, num_classes, input_shape): | ||
for module in net: | ||
if isinstance(module, tf.keras.layers.Dense): | ||
lin_layer = LinearLayer(module.units, module.units, layers[-1]) | ||
act_layer = ReLULayer(lin_layer) | ||
layers.extend([lin_layer, act_layer]) | ||
if isinstance(module, tf.keras.layers.Activation) and module._name == "relu": | ||
relu_layer = ReLULayer(layers[-1]) | ||
layers.append(relu_layer) | ||
if isinstance(module, tf.keras.layers.Conv2D): | ||
if module.padding == "valid": | ||
padding = (0, 0) | ||
elif module.padding == "same": | ||
padding = (1, 1) | ||
conv_layer = Conv2dLayer(1, module.filters, module.kernel_size, module.strides[0], padding, layers[-1]) | ||
layers.append(conv_layer) | ||
if isinstance(module, tf.keras.layers.BatchNormalization): | ||
layers.append(BatchNorm2d(layers[-1])) | ||
if isinstance(module, tf.keras.layers.MaxPool2D): | ||
layers.append(MaxPool2d(module.pool_size, module.strides[0], layers[-1])) | ||
if isinstance(module, tf.keras.layers.GlobalAveragePooling2D): | ||
if module.keepdims: | ||
layers.append(GlobalAvgPool(layers[-1])) | ||
if isinstance(module, tf.keras.layers.Activation) and module._name == "tanh": | ||
tanh_layer = TanHLayer(layers[-1]) | ||
layers.append(tanh_layer) | ||
if isinstance(module, tf.keras.layers.Activation) and module._name == "sigmoid": | ||
sigmoid_layer = SigmoidLayer(layers[-1]) | ||
layers.append(sigmoid_layer) | ||
if isinstance(module, tf.keras.layers.Flatten): | ||
flatten_layer = FlattenLayer(layers[-1]) | ||
layers.append(flatten_layer) | ||
if isinstance(module, tf.keras.layers.Dropout): | ||
dropout_layer = DropoutLayer(layers[-1]) | ||
layers.append(dropout_layer) | ||
return layers |
19 changes: 19 additions & 0 deletions
19
poet/framework_integration/pytorch/test_resnet_graph_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from torch.fx import symbolic_trace | ||
import torchvision | ||
from poet.architectures.graph_transformation import graph_transform | ||
|
||
# transforms ResNet Model graph into POET layers nodes | ||
|
||
# Resnet18 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917 | ||
traced = symbolic_trace(torchvision.models.resnet18(pretrained=True)) | ||
poet_traced = graph_transform(traced) | ||
for n in poet_traced.graph.nodes: | ||
print(n.target) | ||
print(n.name) | ||
|
||
# Resnet50 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917 | ||
traced = symbolic_trace(torchvision.models.resnet50(pretrained=True)) | ||
poet_traced = graph_transform(traced) | ||
for n in poet_traced.graph.nodes: | ||
print(n.target) | ||
print(n.name) |
18 changes: 18 additions & 0 deletions
18
poet/framework_integration/pytorch/test_resnet_network_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from poet.power_computation import InputLayer | ||
import torchvision.models | ||
from poet.architectures.network_transform_pytorch import network_transform | ||
|
||
# transforms ResNet Model network layers into POET computation layers | ||
|
||
batch_size = (1,) | ||
input_shape = (3, 32, 32) | ||
num_classes = 10 | ||
layers = [InputLayer((batch_size, *input_shape))] | ||
|
||
# Resnet18 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917 | ||
final_layers = network_transform(torchvision.models.resnet18(pretrained=True), layers, batch_size, num_classes, input_shape) | ||
print(final_layers) | ||
|
||
# Resnet50 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917 | ||
final_layers = network_transform(torchvision.models.resnet50(pretrained=True), layers, batch_size, num_classes, input_shape) | ||
print(final_layers) |
16 changes: 16 additions & 0 deletions
16
poet/framework_integration/pytorch/test_vgg_network_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from poet.power_computation import InputLayer | ||
import torch | ||
from poet.architectures.network_transform_pytorch import network_transform | ||
|
||
# transforms VGG Model network layers into POET computation layers | ||
|
||
batch_size = (1,) | ||
input_shape = (3, 32, 32) | ||
num_classes = 10 | ||
layers = [InputLayer((batch_size, *input_shape))] | ||
|
||
# VGG16 model transformation - https://download.pytorch.org/models/vgg16-397923af.pth | ||
final_layers = network_transform( | ||
torch.hub.load("pytorch/vision:v0.10.0", "vgg16", pretrained=True), layers, batch_size, num_classes, input_shape | ||
) | ||
print(final_layers) |
81 changes: 81 additions & 0 deletions
81
poet/framework_integration/tensorflow/test_resnet_network_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from poet.architectures.network_transform_tensorflow import network_transform | ||
from poet.power_computation import InputLayer | ||
import tensorflow as tf | ||
|
||
# transforms ResNet Model network layers into POET computation layers | ||
|
||
|
||
# defining ResNet model in TensorFlow | ||
def BasicBlock(inputs, num_channels, kernel_size, num_blocks, skip_blocks, name): | ||
"""Basic residual block""" | ||
x = inputs | ||
for i in range(num_blocks): | ||
if i not in skip_blocks: | ||
x1 = ConvNormRelu(x, num_channels, kernel_size, strides=[1, 1], name=name + "." + str(i)) | ||
x = tf.keras.layers.Add()([x, x1]) | ||
x = tf.keras.layers.Activation("relu")(x) | ||
return x | ||
|
||
|
||
def BasicBlockDown(inputs, num_channels, kernel_size, name): | ||
"""Residual block with strided downsampling""" | ||
x = inputs | ||
x1 = ConvNormRelu(x, num_channels, kernel_size, strides=[2, 1], name=name + ".0") | ||
x = tf.keras.layers.Conv2D( | ||
num_channels, kernel_size=1, strides=2, padding="same", activation="linear", use_bias=False, name=name + ".0.downsample.0" | ||
)(x) | ||
x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".0.downsample.1")(x) | ||
x = tf.keras.layers.Add()([x, x1]) | ||
x = tf.keras.layers.Activation("relu")(x) | ||
return x | ||
|
||
|
||
def ConvNormRelu(x, num_channels, kernel_size, strides, name): | ||
"""Layer consisting of 2 consecutive batch normalizations with 1 first relu""" | ||
if strides[0] == 2: | ||
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name=name + ".pad")(x) | ||
x = tf.keras.layers.Conv2D( | ||
num_channels, kernel_size, strides[0], padding="valid", activation="linear", use_bias=False, name=name + ".conv1" | ||
)(x) | ||
else: | ||
x = tf.keras.layers.Conv2D( | ||
num_channels, kernel_size, strides[0], padding="same", activation="linear", use_bias=False, name=name + ".conv1" | ||
)(x) | ||
x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".bn1")(x) | ||
x = tf.keras.layers.Activation("relu")(x) | ||
x = tf.keras.layers.Conv2D( | ||
num_channels, kernel_size, strides[1], padding="same", activation="linear", use_bias=False, name=name + ".conv2" | ||
)(x) | ||
x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name=name + ".bn2")(x) | ||
return x | ||
|
||
|
||
def ResNet18(inputs): | ||
x = tf.keras.layers.ZeroPadding2D(padding=(3, 3), name="pad")(inputs) | ||
x = tf.keras.layers.Conv2D(filters=64, kernel_size=7, strides=2, padding="valid", activation="linear", use_bias=False, name="conv1")(x) | ||
x = tf.keras.layers.BatchNormalization(momentum=0.1, epsilon=1e-5, name="bn1")(x) | ||
x = tf.keras.layers.Activation("relu", name="relu")(x) | ||
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name="pad1")(x) | ||
x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding="valid", name="maxpool")(x) | ||
x = BasicBlock(x, num_channels=64, kernel_size=3, num_blocks=2, skip_blocks=[], name="layer1") | ||
x = BasicBlockDown(x, num_channels=128, kernel_size=3, name="layer2") | ||
x = BasicBlock(x, num_channels=128, kernel_size=3, num_blocks=2, skip_blocks=[0], name="layer2") | ||
x = BasicBlockDown(x, num_channels=256, kernel_size=3, name="layer3") | ||
x = BasicBlock(x, num_channels=256, kernel_size=3, num_blocks=2, skip_blocks=[0], name="layer3") | ||
x = BasicBlockDown(x, num_channels=512, kernel_size=3, name="layer4") | ||
x = BasicBlock(x, num_channels=512, kernel_size=3, num_blocks=2, skip_blocks=[0], name="layer4") | ||
x = tf.keras.layers.GlobalAveragePooling2D(name="avgpool")(x) | ||
x = tf.keras.layers.Dense(units=1000, use_bias=True, activation="linear", name="fc")(x) | ||
return x | ||
|
||
|
||
inputs = tf.keras.Input((None, None, 3)) | ||
resnet_tf = ResNet18(inputs) | ||
model = tf.keras.Model(inputs, resnet_tf) | ||
batch_size = (1,) | ||
input_shape = (3, 32, 32) | ||
num_classes = 10 | ||
layers = [InputLayer((batch_size, *input_shape))] | ||
|
||
final_layers = network_transform([layer for layer in model.layers], layers, batch_size, num_classes, input_shape) | ||
print(final_layers) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from torch.fx import symbolic_trace | ||
import torchvision | ||
from poet.architectures.graph_transformation import graph_transform | ||
|
||
# transforms ResNet Model graph into POET layers nodes | ||
|
||
# Resnet18 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917 | ||
traced = symbolic_trace(torchvision.models.resnet18(pretrained=True)) | ||
poet_traced = graph_transform(traced) | ||
for n in poet_traced.graph.nodes: | ||
print(n.target) | ||
print(n.name) | ||
|
||
|
||
# Resnet50 model transformation - https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py, commit: 7dc5e5bd60b55eb4e6ea5c1265d6dc7b17d2e917 | ||
traced = symbolic_trace(torchvision.models.resnet50(pretrained=True)) | ||
poet_traced = graph_transform(traced) | ||
for n in poet_traced.graph.nodes: | ||
print(n.target) | ||
print(n.name) | ||
assert n.target == "output" and n.name == "output" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We say network transform, but this is mostly for Resnet+common layers? I think we should have a better way to organize this. I don't know exactly how - open to suggestions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the goal was to extend this for other models like BERT which I didn't fully support in my implementation yet. I can comment out this import for now.