Adding unpooling layers to a pretrained AlexNet

Hi,

I am trying to add unpooling and deconvolution (transposed convolution) layers to a pretrained AlexNet. The pooling layers need `return_indices=True`, so I don't think I can use `nn.Sequential`: each pooling layer would then return two outputs (the pooled tensor and the indices).
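Because of that, I now define every layer as a separate attribute and call the layers one by one in `forward` (using the functional API for the ReLUs). However, `load_state_dict` no longer accepts the pretrained weights downloaded from the provided URLs. The checkpoint uses keys such as `features.0.weight`, whereas my model expects `conv1.weight`. Is there a way to load the state dict under new key names? Or is there a way to use `nn.Sequential` with `return_indices=True` in the pooling layers?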

class AlexNetrec(nn.Module):
    """AlexNet-style conv encoder followed by a mirrored unpool/deconv decoder.

    The encoder (conv1-conv5, pool1-pool3) records the max-pooling indices.
    The decoder walks back through the same stages in reverse order. It uses
    those indices in ``MaxUnpool2d`` and a ``ConvTranspose2d`` per encoder
    conv, whose in/out channels are that conv's out/in channels.

    Args:
        num_classes: Accepted for signature compatibility only; the model has
            no classifier head, so it is unused.
    """

    def __init__(self, num_classes=1000):
        super(AlexNetrec, self).__init__()
        # Encoder. return_indices=True makes each pool return
        # (pooled, indices) so the decoder can undo the pooling.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)
        self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)

        # Decoder. Each deconvN maps convN's *output* channels back to its
        # *input* channels, so the channel counts chain correctly in reverse.
        self.unpool3 = nn.MaxUnpool2d(kernel_size=3, stride=2)
        self.deconv5 = nn.ConvTranspose2d(256, 256, kernel_size=3, padding=1)
        self.deconv4 = nn.ConvTranspose2d(256, 384, kernel_size=3, padding=1)
        self.deconv3 = nn.ConvTranspose2d(384, 192, kernel_size=3, padding=1)
        self.unpool2 = nn.MaxUnpool2d(kernel_size=3, stride=2)
        self.deconv2 = nn.ConvTranspose2d(192, 64, kernel_size=5, padding=2)
        self.unpool1 = nn.MaxUnpool2d(kernel_size=3, stride=2)
        self.deconv1 = nn.ConvTranspose2d(64, 3, kernel_size=11, stride=4, padding=2)

    def forward(self, x):
        """Encode ``x`` and reconstruct it through the mirrored decoder.

        Args:
            x: Image tensor with 3 channels, presumably (N, 3, H, W).
                H and W must be large enough to survive the three poolings.

        Returns:
            A tensor with the same spatial size and channel count as ``x``.
        """
        out1 = F.relu(self.conv1(x))
        out2, indxp1 = self.pool1(out1)
        out3 = F.relu(self.conv2(out2))
        out4, indxp2 = self.pool2(out3)
        out5 = F.relu(self.conv3(out4))
        out6 = F.relu(self.conv4(out5))
        out7 = F.relu(self.conv5(out6))
        out8, indxp3 = self.pool3(out7)

        # MaxUnpool2d's default output size, (in - 1) * stride + kernel_size,
        # can be smaller than the tensor that was pooled (pooling floors the
        # size). The indices would then point outside the output, so pass the
        # pre-pooling spatial size explicitly.
        out9 = self.unpool3(out8, indxp3, output_size=out7.shape[-2:])
        out10 = F.relu(self.deconv5(out9))
        out11 = F.relu(self.deconv4(out10))
        out12 = F.relu(self.deconv3(out11))
        out13 = self.unpool2(out12, indxp2, output_size=out3.shape[-2:])
        out14 = F.relu(self.deconv2(out13))
        out15 = self.unpool1(out14, indxp1, output_size=out1.shape[-2:])
        # The strided transposed conv has the same ambiguity (e.g. 223 vs 224),
        # so request the input's spatial size explicitly.
        out16 = F.relu(self.deconv1(out15, output_size=x.shape[-2:]))
        return out16

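A `state_dict` is essentially an ordinary (ordered) dictionary from parameter names to tensors. So you can define a dict that maps each pretrained key to the corresponding key in your model: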

# Pretrained torchvision AlexNet key -> AlexNetrec key. In torchvision's
# AlexNet the conv layers sit at indices 0, 3, 6, 8 and 10 of the `features`
# nn.Sequential (the ReLU/MaxPool2d entries in between have no parameters).
mapping = {
    'features.0.weight': 'conv1.weight',
    'features.0.bias': 'conv1.bias',
    'features.3.weight': 'conv2.weight',
    'features.3.bias': 'conv2.bias',
    'features.6.weight': 'conv3.weight',
    'features.6.bias': 'conv3.bias',
    'features.8.weight': 'conv4.weight',
    'features.8.bias': 'conv4.bias',
    'features.10.weight': 'conv5.weight',
    'features.10.bias': 'conv5.bias',
}

Then copy the pretrained tensors into your model's state dict under the new names and load it:

# Start from the model's own state dict. Keys with no pretrained counterpart
# (the decoder layers) keep their current values, and load_state_dict, which
# is strict by default, still finds every key it expects.
state_dict = net.state_dict()
# Fail loudly if a mapped key is missing from the checkpoint (e.g. a typo
# such as 'feature.0' vs 'features.0'). Otherwise the filter below would
# silently skip it and leave that layer randomly initialized.
unmatched = [k for k in mapping if k not in pretrained_state_dict]
if unmatched:
    raise KeyError(f'mapped keys not found in pretrained state dict: {unmatched}')
# Copy only the pretrained tensors that have a new name. Checkpoint keys the
# mapping does not cover (e.g. a classifier head) are skipped instead of
# raising KeyError in mapping[k].
state_dict.update(
    {mapping[k]: v for k, v in pretrained_state_dict.items() if k in mapping}
)
net.load_state_dict(state_dict)
1 Like