-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencoder.py
More file actions
39 lines (32 loc) · 1.35 KB
/
encoder.py
File metadata and controls
39 lines (32 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch.nn as nn
import torchvision.models as models
class EncoderCNN(nn.Module):
    """
    Image encoder wrapping a frozen, pre-trained torchvision CNN backbone.

    ``forward`` runs the backbone and returns its final feature map
    flattened into a sequence of per-location feature vectors.
    """

    def __init__(self, model="vgg16"):
        """
        :param model: name of pre-trained CNN to use; one of ``"vgg16"``,
            ``"resnet50"`` or ``"inception_v3"``
        :raises ValueError: if ``model`` is not one of the supported names
        """
        super(EncoderCNN, self).__init__()
        if model == "vgg16":
            # Keep only the ``features`` submodule of VGG-16.
            features = models.vgg16(pretrained=True).features
            self.cnn = nn.Sequential(*list(features))
        elif model == "resnet50":
            # Drop the last two children (expected to be the global average
            # pool and the fc classifier -- verify for your torchvision
            # version) so the output remains a spatial feature map.
            layers = models.resnet50(pretrained=True).children()
            self.cnn = nn.Sequential(*list(layers)[:-2])
        elif model == "inception_v3":
            # Drop the last three children (in recent torchvision: avgpool,
            # dropout, fc). NOTE(review): the child list has differed across
            # torchvision releases, and chaining the children in
            # nn.Sequential bypasses Inception3.forward (e.g. its
            # ``transform_input`` step) -- confirm this matches intent.
            layers = models.inception_v3(pretrained=True, aux_logits=False).children()
            self.cnn = nn.Sequential(*list(layers)[:-3])
        else:
            # Fail fast: otherwise ``self.cnn`` is never set and the freeze
            # loop below dies with an unhelpful AttributeError.
            raise ValueError(
                f"Unsupported model {model!r}; expected one of "
                "'vgg16', 'resnet50', 'inception_v3'"
            )
        # Freeze the backbone so no gradients are computed for its weights.
        # NOTE(review): this does not freeze BatchNorm running statistics (if
        # the backbone has BatchNorm layers); they still update whenever the
        # module is in train mode.
        for param in self.cnn.parameters():
            param.requires_grad_(False)

    def forward(self, images):
        """
        Forward propagation.

        :param images: images, a tensor of size (batch_size, 3, image_size, image_size)
        :return: encoded images, a tensor of size (batch_size, h * w, channels),
            where h, w and channels are those of the backbone's output map
        """
        features = self.cnn(images)  # (batch, channels, h, w)
        features = features.permute(0, 2, 3, 1)  # (batch, h, w, channels)
        # flatten() (unlike view()) also works when the permuted tensor's
        # strides do not allow merging h and w without a copy.
        features = features.flatten(1, 2)  # (batch, h * w, channels)
        return features