Hi,
I am trying to add unpooling and deconvolution (transposed convolution) layers to the pretrained AlexNet. Since I need to set return_indices=True in the pooling layers, I don't think I can use nn.Sequential (each pooling layer would then return two outputs: the pooled tensor and the indices).
Now that I define the layers individually and call them in forward, load_state_dict no longer loads the weights from the provided URLs, because the checkpoint has keys such as features.0.weight while my model expects conv1.weight. Is there any way to modify the state dict so that the weights are loaded under the new keys? Or is there any way to use Sequential with return_indices=True in the pooling layers?
class AlexNetrec(nn.Module):
    """AlexNet-style convolutional encoder followed by a mirrored decoder.

    The encoder is five convolutions with three max-pooling stages that
    return their argmax indices. The decoder walks back through the same
    stages in reverse order: each pooling stage is undone with
    ``MaxUnpool2d`` (fed the saved indices) and each convolution with a
    ``ConvTranspose2d`` whose channel counts are the reverse of the
    convolution it mirrors.

    Args:
        num_classes: Accepted so the constructor signature stays compatible
            with callers that pass it; it is not used because this module
            defines no classifier head.
    """

    # Maps the parameter prefixes of a ``features.N``-style checkpoint to the
    # attribute names used here.
    # NOTE(review): indices follow torchvision's AlexNet ``features``
    # Sequential (conv layers at 0, 3, 6, 8, 10) -- confirm against the
    # checkpoint actually being loaded.
    _PRETRAINED_KEY_MAP = (
        ("features.0", "conv1"),
        ("features.3", "conv2"),
        ("features.6", "conv3"),
        ("features.8", "conv4"),
        ("features.10", "conv5"),
    )

    def __init__(self, num_classes=1000):
        super(AlexNetrec, self).__init__()

        # ---- encoder -------------------------------------------------
        self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)
        self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)

        # ---- decoder -------------------------------------------------
        # ConvTranspose2d takes (in_channels, out_channels): each deconvN
        # consumes what convN produced and emits what convN consumed, so the
        # channel pair is the reverse of the matching Conv2d.
        self.unpool3 = nn.MaxUnpool2d(kernel_size=3, stride=2)
        self.deconv5 = nn.ConvTranspose2d(256, 256, kernel_size=3, padding=1)
        self.deconv4 = nn.ConvTranspose2d(256, 384, kernel_size=3, padding=1)
        self.deconv3 = nn.ConvTranspose2d(384, 192, kernel_size=3, padding=1)
        self.unpool2 = nn.MaxUnpool2d(kernel_size=3, stride=2)
        self.deconv2 = nn.ConvTranspose2d(192, 64, kernel_size=5, padding=2)
        self.unpool1 = nn.MaxUnpool2d(kernel_size=3, stride=2)
        self.deconv1 = nn.ConvTranspose2d(
            64, 3, kernel_size=11, stride=4, padding=2
        )

    def load_pretrained_encoder(self, state_dict):
        """Copy encoder conv weights from a ``features.N``-keyed state dict.

        Only the ``weight`` and ``bias`` entries of the five mapped
        convolution layers are read; any other keys in ``state_dict`` are
        ignored, and the decoder parameters are left untouched.

        Args:
            state_dict: Mapping that contains ``features.N.weight`` and
                ``features.N.bias`` for every prefix in
                ``_PRETRAINED_KEY_MAP``.

        Raises:
            KeyError: If one of the expected keys is missing.
        """
        remapped = {}
        for old_prefix, new_prefix in self._PRETRAINED_KEY_MAP:
            for suffix in (".weight", ".bias"):
                remapped[new_prefix + suffix] = state_dict[old_prefix + suffix]
        # strict=False: the decoder has no counterpart in the checkpoint.
        self.load_state_dict(remapped, strict=False)

    def forward(self, x):
        """Encode ``x`` and decode it back to a tensor of the same shape.

        Args:
            x: Input batch with 3 channels (``conv1`` requires it).

        Returns:
            The ReLU-activated output of ``deconv1``; its spatial size is
            requested to equal that of ``x``.
        """
        input_hw = x.size()[2:]

        # Encoder: remember each pre-pool spatial size so unpooling can
        # restore it exactly. The default MaxUnpool2d size is one pixel short
        # whenever the pooled map had an even size, which would make the
        # saved indices refer to the wrong plane layout.
        act1 = F.relu(self.conv1(x))
        pooled1, idx1 = self.pool1(act1)
        act2 = F.relu(self.conv2(pooled1))
        pooled2, idx2 = self.pool2(act2)
        act3 = F.relu(self.conv3(pooled2))
        act4 = F.relu(self.conv4(act3))
        act5 = F.relu(self.conv5(act4))
        pooled3, idx3 = self.pool3(act5)

        # Decoder: mirror of the encoder, deepest stage first.
        up3 = self.unpool3(pooled3, idx3, output_size=act5.size()[2:])
        dec5 = F.relu(self.deconv5(up3))
        dec4 = F.relu(self.deconv4(dec5))
        dec3 = F.relu(self.deconv3(dec4))
        up2 = self.unpool2(dec3, idx2, output_size=act2.size()[2:])
        dec2 = F.relu(self.deconv2(up2))
        up1 = self.unpool1(dec2, idx1, output_size=act1.size()[2:])
        # The stride-4 transposed conv has several valid output sizes; ask
        # for the one that matches the original input.
        return F.relu(self.deconv1(up1, output_size=input_hw))