Resnet last layer modification

Hello guys, I’m trying to add a dropout layer before the FC layer at the “bottom” of my resnet. To do that, I remove the original FC layer from the resnet18 with the following code:

    import torch
    import torch.nn as nn
    from torchvision import models
    resnetk = models.resnet18(pretrained=True)
    num_ftrs = resnetk.fc.in_features
    resnetk = torch.nn.Sequential(*list(resnetk.children())[:-1])

Then, I add the dropout and the FC layer using the num_ftrs I obtained from the previous (original) FC layer of my resnet18:

    resnetk.add_module("dropout", nn.Dropout(p=0.5))
    resnetk.add_module("fc", nn.Linear(num_ftrs, n_classes))

But I receive the following error: RuntimeError: size mismatch, m1: [8192 x 1], m2: [512 x 2] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:266

I’m also confused about where the softmax comes in after the linear layer, since in Keras we need to specify the activation function as softmax.

Currently you are wrapping your pretrained resnet into a new nn.Sequential module, which will lose the original forward definition. As you can see in this line of code in the original resnet implementation, the activation x is flattened before being passed to the last linear layer. Since this flattening is missing now, you’ll get the size mismatch error.

You could just manipulate the model.fc attribute and add your dropout and linear layer:

import torch.nn as nn
from torchvision import models

model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_ftrs, 10)
)

This approach will keep your forward method.
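Alternatively, if you prefer to keep the nn.Sequential rewrap, a minimal sketch (assuming a PyTorch version that ships nn.Flatten, and assuming n_classes=2 as suggested by the error message) would be to restore the missing flatten explicitly:

    import torch
    import torch.nn as nn
    from torchvision import models

    backbone = models.resnet18(pretrained=True)
    num_ftrs = backbone.fc.in_features  # 512 for resnet18

    resnetk = nn.Sequential(
        *list(backbone.children())[:-1],  # everything up to and including the avgpool
        nn.Flatten(1),                    # replaces the flatten done in the original forward
        nn.Dropout(p=0.5),
        nn.Linear(num_ftrs, 2),           # n_classes assumed to be 2 here
    )

    x = torch.randn(16, 3, 224, 224)
    print(resnetk(x).shape)  # torch.Size([16, 2])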

It depends on the criterion you are using whether you should add a final non-linearity in your model.
If you are dealing with a classification use case and would like to use nn.CrossEntropyLoss, you should pass the raw logits to the loss function (i.e. no final non-linearity), since nn.LogSoftmax and nn.NLLLoss will be called internally.
However, if you would like to use nn.NLLLoss, you should add the nn.LogSoftmax manually.
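To make the relationship concrete, here is a small sketch (with random logits and targets) showing that nn.CrossEntropyLoss on the raw logits matches nn.LogSoftmax followed by nn.NLLLoss:

    import torch
    import torch.nn as nn

    logits = torch.randn(8, 5)               # raw model outputs, no final non-linearity
    target = torch.randint(0, 5, (8,))

    loss_ce = nn.CrossEntropyLoss()(logits, target)
    loss_nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)

    print(torch.allclose(loss_ce, loss_nll))  # True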

Thanks @ptrblck, your comments and observations helped me to understand my problem. And I appreciate that you pointed out the line of code in the original resnet implementation, it helped even more!

Hi,
Isn’t this like adding dropout to the last layer? Why would we want to drop some of the outputs?

Regards,

No, as the dropout layer is added before the output layer, not after it.

I am very new to this field. Would you agree with me if I say: since the model is built sequentially, whatever is written first is applied first? And functional.dropout is applied after a layer because, in the forward method, we call the layers one after another, so the dropout comes after the layer has produced its activation?

The modules in an nn.Sequential container are executed in the order they were passed in, which is also the order in which their forward methods will be called.
Here is a small example, which shows that the first case yields a dense output, while the second one zeroes out some output units:

import torch
import torch.nn as nn

# Case 1
model = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(10, 10)
)

x = torch.randn(1, 10)
out = model(x)
print(out)

# Case 2
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.Dropout(0.5)
)

x = torch.randn(1, 10)
out = model(x)
print(out)
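
As a side note (standard nn.Dropout behavior, not specific to this snippet): the zeroed entries only appear in training mode; after switching the second model to evaluation mode the dropout becomes a no-op and the output is dense again:

model.eval()
print(model(x))  # no zeroed entries anymore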

Thank you so much for this example.
Could you tell me what we are dropping in ResNet before the last layer, given that there is no linear layer before the last layer but a convolutional one?

Regards

In this example, you would drop some of the features, which are fed into the linear layer.
I just provided the code snippet for @Paulo_Mann, so he might know better what his use case is. :wink:
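
For reference, a minimal sketch of what gets dropped in the resnet18 case (shapes taken from the torchvision implementation): the tensor the dropout sees is the flattened output of the global average pooling, i.e. one 512-dimensional feature vector per sample, not the convolutional activation map itself:

    import torch
    import torch.nn as nn

    pooled = torch.randn(4, 512)        # stand-in for the flattened avgpool output of resnet18
    drop = nn.Dropout(p=0.5)            # the module is in training mode by default

    out = drop(pooled)
    print((out == 0.0).float().mean())  # roughly 0.5 of the features are zeroed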

I always thought that Dropout is meant to drop some neurons in a Linear layer to reduce the outgoing features. I did not know that it is used to drop incoming features as well.

Thank you very much for this information. Also, if you have a source handy where I can read more about this behavior of dropout, could you please share it? :slight_smile:

BR,

Can I ask: when I tried to add another linear layer, it came out with

RuntimeError: Error(s) in loading state_dict for ResNet: size mismatch for fc.2.weight: copying a param with shape torch.Size([20, 100]) from checkpoint, the shape in current model is torch.Size([60, 100]). size mismatch for fc.2.bias: copying a param with shape torch.Size([20]) from checkpoint, the shape in current model is torch.Size([60]). size mismatch for fc.5.weight: copying a param with shape torch.Size([4, 20]) from checkpoint, the shape in current model is torch.Size([20]). size mismatch for fc.5.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([20]).

I’ve already set strict=False when I load my model. What else could I be doing wrong?

def initModel4():
    resnetbase = models.resnet34(pretrained=True).to(device)

    for param in resnetbase.parameters():  # Freeze all the layers and train only the last layer
        param.requires_grad = False

    # add in dropout layer
    fc_layers = nn.Sequential(
        nn.Linear(resnetbase.fc.in_features, 100),
        nn.Hardswish(),
        nn.Linear(100, 20),
        nn.BatchNorm1d(20),
        nn.Dropout(0.12),
        nn.Linear(20, 10),
        nn.Linear(10, 4),
    ).to(device)

    # weights for background imgs, tanks, floating head tanks, tank clusters
    class_weights = torch.FloatTensor([0.1, 0.5, 0.3, 2.0]).to(device)
    loss_function = nn.CrossEntropyLoss(torch.FloatTensor([0.2, 0.6, 0.4, 2.0])).to(device)

    optimizer = torch.optim.Adam(resnetbase.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    return

If you want to change some layers in a pretrained model, make sure to load the state_dict containing the parameters and buffers in the expected shapes before applying any manipulations, such as assigning new layers to attributes.
strict=False wouldn’t work in this case, as there are no key mismatches, but the shapes are wrong.
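
For example, a minimal sketch of that order (the checkpoint path is a placeholder and the snippet assumes the checkpoint was saved from the unmodified resnet34): load the state_dict while every shape still matches, and only afterwards assign the new layers:

    import torch
    import torch.nn as nn
    from torchvision import models

    # 1) build the model with the same layer shapes the checkpoint was saved with
    model = models.resnet34(pretrained=False)

    # 2) load the parameters while all shapes still match
    model.load_state_dict(torch.load("checkpoint.pth"))  # placeholder path

    # 3) only now replace or extend the classifier head
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 100),
        nn.Hardswish(),
        nn.Linear(100, 4),  # layer sizes here are just an example
    )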

PS: you can post code snippets by wrapping them into three backticks ```, which makes debugging easier :wink:

Thank you once again! Apologies for the late reply, as I was working on other things for a while.
And oops! Thanks for the tip! Will take note of it next time! :slight_smile:

Hi Ptrblck,

I am trying to add extra layers in a ResNet-like way: the output of the added layer is summed with its input and then passed to the next layer (not exactly like ResNet; by adding Conv3d layers (self.l44 and self.l55) between my generator layers, I want to add more features to the generator).

Should I use Conv3d or Conv1d in the last layer (self.l6)? I think Conv3d is correct. Can I write the last layer with Conv1d?

import torch
import torch.nn as nn

ngpu=1
nz=11
ngf=8

class Generator(nn.Module):
    def __init__(self,ngpu,nz,ngf):
        super(Generator, self).__init__()
        self.ngpu=ngpu
        self.nz=nz
        self.ngf=ngf
   
        ## ---1x11x1x1x1
        self.l1= nn.Sequential(
            nn.ConvTranspose3d(self.nz+1, self.ngf * 6, 3, 1, 0, bias=True),
            nn.BatchNorm3d(self.ngf * 6),
            nn.ReLU(),)
        ##---48x3x3x3
        self.l2=nn.Sequential(nn.ConvTranspose3d(self.ngf * 6, self.ngf * 4, 3, 2, 0, bias=True),
            nn.BatchNorm3d(self.ngf * 4),
            nn.ReLU(),)
        ## ---32x7x7x7
        self.l3=nn.Sequential(nn.ConvTranspose3d( self.ngf * 4, self.ngf * 2, 3, 1, 0, bias=True),
            nn.BatchNorm3d(self.ngf * 2),
             nn.ReLU(),)
        ## ----16x9x9x9
        self.l4=nn.Sequential(nn.ConvTranspose3d( self.ngf*2, self.ngf*2, 3, 1, 0, bias=True),nn.BatchNorm3d(self.ngf * 2),
             nn.ReLU(),)
        ## ----16x9x9x9
        self.l44=nn.Sequential(nn.Conv3d(ngf * 2, self.ngf * 2, 3, 1, 1, bias=True),nn.BatchNorm3d(ngf * 2),nn.ReLU())
        ## ---16x11x11x11
        self.l5=nn.Sequential(nn.ConvTranspose3d( self.ngf*2, self.ngf*2, 3, 1, 0, bias=True),nn.BatchNorm3d(self.ngf * 2),
             nn.ReLU(),)
        ## ---16x11x11x11
        self.l55=nn.Sequential(nn.Conv3d(ngf * 2, self.ngf*2, 3, 1, 1, bias=True),nn.BatchNorm3d(ngf * 2),nn.ReLU())
        ## ---1x11x11x11
        self.l6=nn.Sequential(nn.Conv3d( self.ngf*2, 1, 3, 1, 0, bias=True),nn.Sigmoid())

    def forward(self, input,Labels):

        Labels=Labels.unsqueeze(1).unsqueeze(2).unsqueeze(3).unsqueeze(4)
        InputCat=torch.cat((Labels,input),1)
        out1=self.l1(InputCat)
        out2=self.l2(out1)
        out3=self.l3(out2)
        out4=self.l4(out3)    
        out5=self.l44(out4)
        ## ---- add layers---
        out6=out5+out4
        out7=self.l5(out6)
        out8=self.l55(out7)
        ## ---- add layers---
        out9=out8+out7
        out_total=self.l6(out9)
  
        return out_total

batchsize=10
Noise=torch.randn(batchsize,nz,1,1,1)
Condition=torch.ones(size=(batchsize,))

Gen=Generator(ngpu,nz,ngf)
Out=Gen(Noise,Condition)

The difference between nn.Conv1d and nn.Conv3d is the expected shape of the input (and thus also the kernel shape etc.).
nn.Conv1d would be used for “temporal” signals in the shape [batch_size, channels, seq_len], while nn.Conv3d is used for “volumetric” signals in the shape [batch_size, channels, depth, height, width].
Based on the comments from your code you are starting with a 5D tensor, which suddenly gets reduced to a 4D one, so I would assume that either the comments are wrong or the model is not working at the moment.
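
A minimal sketch of those expected input shapes (arbitrary sizes, kernel_size=3, no padding) for comparison:

    import torch
    import torch.nn as nn

    conv1d = nn.Conv1d(in_channels=16, out_channels=1, kernel_size=3)
    conv3d = nn.Conv3d(in_channels=16, out_channels=1, kernel_size=3)

    x1 = torch.randn(10, 16, 11)            # [batch_size, channels, seq_len]
    x3 = torch.randn(10, 16, 11, 11, 11)    # [batch_size, channels, depth, height, width]

    print(conv1d(x1).shape)  # torch.Size([10, 1, 9])
    print(conv3d(x3).shape)  # torch.Size([10, 1, 9, 9, 9])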

Do you mean that you get an error from the code? Could you please tell me what you mean by “it is not working”?
I ran it and did not get any error. Before layer 6 (self.l6) the shape is batch_size=10, channels=16, depth=11, height=11, width=11; after that layer I get (10, 1, 11, 11, 11) as output.
My aim is to generate 3D patches of shape 10x1x11x11x11.
The input is a noise tensor of shape 10x11x1x1x1, which is concatenated with the Labels, giving 10x12x1x1x1 as the input of the first layer (so the comment 1x11x1x1x1 is wrong).