class Encoder(nn.Module):
    """Convolutional image encoder built on a pretrained ResNet-101.

    The classification head of the ResNet (its final pooling and linear
    layers) is dropped, and the resulting feature map is adaptively pooled
    to a fixed ``enc_image_size x enc_image_size`` grid so that input
    images of varying size yield an encoding of constant spatial size.
    """

    def __init__(self, enc_image_size: int = 14) -> None:
        """Build the encoder.

        Args:
            enc_image_size: Height and width of the pooled feature map
                returned by :meth:`forward`.
        """
        super().__init__()
        self.enc_image_size = enc_image_size

        # NOTE: newer torchvision releases prefer the ``weights=`` argument;
        # ``pretrained=True`` is kept for compatibility with older versions.
        resnet = torchvision.models.resnet101(pretrained=True)

        # Remove the last two children (pooling and linear layers): only
        # the convolutional feature extractor is needed.
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)

        # Resize the feature map to a fixed size regardless of the input
        # image size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d(
            (enc_image_size, enc_image_size)
        )

        self.fine_tune()

    def fine_tune(self, fine_tune: bool = True) -> None:
        """Enable or disable gradient computation for the later backbone stages.

        Every backbone parameter is frozen first; then the parameters of all
        children from index 5 onward have ``requires_grad`` set to
        ``fine_tune``. The first five children always stay frozen.

        Args:
            fine_tune: Whether the later children should be trainable.
        """
        for param in self.resnet.parameters():
            param.requires_grad = False

        # With the ResNet-101 backbone, children 0-4 are the stem and the
        # first residual stage; only the deeper stages are fine-tuned.
        for child in list(self.resnet.children())[5:]:
            for param in child.parameters():
                param.requires_grad = fine_tune

    def forward(self, images):
        """Encode a batch of images.

        Args:
            images: Image tensor of shape
                ``(batch_size, channels, height, width)``.

        Returns:
            Feature tensor with the channel axis moved last, of shape
            ``(batch_size, enc_image_size, enc_image_size, feature_channels)``
            (2048 feature channels for ResNet-101).
        """
        out = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
        out = self.adaptive_pool(out)  # (batch_size, 2048, enc_image_size, enc_image_size)
        out = out.permute(0, 2, 3, 1)  # (batch_size, enc_image_size, enc_image_size, 2048)
        return out