How to get output from intermediate layer in nn.Sequential and reshape it before passing to next layer?

Hello!

I have a pretrained model whose weights I wish to use. Its last sequential layer is as follows:

    self.output_layer = Sequential(BatchNorm2d(512), 
                                   Dropout(drop_ratio),
                                   Flatten(),
                                   Linear(512 * 7 * 7, 512),
                                   BatchNorm1d(512))

The linear layer gives me an output of shape 1 x 512. However, in place of BatchNorm1d I need to use BatchNorm2d, since BatchNorm2d is supported for conversion on the edge device I am working with and BatchNorm1d is not. I can change BatchNorm1d to BatchNorm2d and load the weights of the pretrained model, but BatchNorm2d expects a 4D tensor while the Linear layer gives me a 2D tensor. How can I solve this? I suppose we cannot add a reshape module between Linear and BatchNorm2d, as it would throw an error while loading the weights. Can we take the output after Linear, reshape it with a custom reshape module, and then pass it to BatchNorm2d within the same sequential block? How can I do that, if it is doable? Or is there another workaround for this?
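
For reference, the swap is only possible because BatchNorm1d and BatchNorm2d register identical parameters and buffers for the same num_features, so their state_dicts match key for key; a quick sanity check (not part of the original model code):

    import torch.nn as nn

    bn1d = nn.BatchNorm1d(512)
    bn2d = nn.BatchNorm2d(512)

    # both layers store weight, bias, running_mean, running_var, and
    # num_batches_tracked with identical shapes
    for (k1, v1), (k2, v2) in zip(bn1d.state_dict().items(),
                                  bn2d.state_dict().items()):
        print(k1, tuple(v1.shape), k2, tuple(v2.shape))
    # weight (512,) weight (512,)
    # bias (512,) bias (512,)
    # running_mean (512,) running_mean (512,)
    # running_var (512,) running_var (512,)
    # num_batches_tracked () num_batches_tracked ()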

Thanks!

You could try to insert a custom layer to reshape the activation after loading the state_dict:

    import copy

    import torch
    import torch.nn as nn

    # create reference
    output_layer = nn.Sequential(nn.BatchNorm2d(512),
                                 nn.Dropout(0.5),
                                 nn.Flatten(),
                                 nn.Linear(512 * 7 * 7, 512),
                                 nn.BatchNorm1d(512))

    sd = copy.deepcopy(output_layer.state_dict())

    # restore into a container that uses BatchNorm2d as the last layer
    layer = nn.Sequential(nn.BatchNorm2d(512),
                          nn.Dropout(0.5),
                          nn.Flatten(),
                          nn.Linear(512 * 7 * 7, 512),
                          nn.BatchNorm2d(512))

    layer.load_state_dict(sd)
    # <All keys matched successfully>

    x = torch.randn(2, 512, 7, 7)
    out = layer(x)
    # ValueError: expected 4D input (got 2D input)

    class Reshape(nn.Module):
        def __init__(self, shape):
            super().__init__()
            self.shape = shape

        def forward(self, x):
            x = x.reshape(self.shape)
            return x

    # insert the reshape module right before the final BatchNorm2d
    layer.insert(-1, Reshape((-1, 512, 1, 1)))
    out = layer(x)
    print(out.shape)
    # torch.Size([2, 512, 1, 1])
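
Two details are worth spelling out. The state_dict loads cleanly because nn.Sequential keys its submodules by position ("0.weight", "3.bias", "4.running_mean", and so on), and those positions and shapes are identical in both containers; this is also why the Reshape module has to be inserted after load_state_dict, since inserting it first would shift the index of the final batchnorm layer and the keys would no longer match. Note as well that nn.Sequential.insert is a relatively recent addition to PyTorch; on versions without it, rebuilding the container from its children achieves the same effect (a sketch, reusing the Reshape class above):

    # splice the reshape module in before the final BatchNorm2d, then
    # rebuild the Sequential from the same (already loaded) modules
    modules = list(layer.children())
    modules.insert(-1, Reshape((-1, 512, 1, 1)))
    layer = nn.Sequential(*modules)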

Hi @ptrblck,

Thank you for the apt solution. I used a different workaround, which I feel is a little lengthy and naive.

I created a new ReshapeModule:

    class ReshapeModule(Module):
        def forward(self, x):
            x = x.view(1, 512, 1, 1)
            return x
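
One caveat with this module (my observation, not from the original post): x.view(1, 512, 1, 1) hardcodes a batch size of 1 and will fail for larger batches. A batch-size-agnostic variant would be:

    class ReshapeModule(Module):
        def forward(self, x):
            # infer the batch dimension instead of hardcoding it
            return x.view(x.size(0), 512, 1, 1)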

In the model's forward method:

    x = self.output_layer[:-1](x)
    x = self.to_2d_bn(x)
    x = self.output_layer[-1:](x)

where,

    self.output_layer = Sequential(BatchNorm2d(512),
                                   Dropout(drop_ratio),
                                   Flatten(),
                                   Linear(512 * 7 * 7, 512),
                                   BatchNorm2d(512))
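
To make the slicing workaround concrete, here is a minimal end-to-end sketch; the Head class name, the drop_ratio default, and the to_2d_bn wiring are illustrative assumptions rather than the original model code:

    import torch
    from torch.nn import (BatchNorm2d, Dropout, Flatten, Linear,
                          Module, Sequential)

    class ReshapeModule(Module):
        def forward(self, x):
            return x.view(x.size(0), 512, 1, 1)

    class Head(Module):  # hypothetical wrapper around the output layer
        def __init__(self, drop_ratio=0.5):
            super().__init__()
            self.output_layer = Sequential(BatchNorm2d(512),
                                           Dropout(drop_ratio),
                                           Flatten(),
                                           Linear(512 * 7 * 7, 512),
                                           BatchNorm2d(512))
            self.to_2d_bn = ReshapeModule()

        def forward(self, x):
            x = self.output_layer[:-1](x)  # BatchNorm2d ... Linear -> (N, 512)
            x = self.to_2d_bn(x)           # -> (N, 512, 1, 1)
            x = self.output_layer[-1:](x)  # final BatchNorm2d
            return x

    head = Head()
    out = head(torch.randn(2, 512, 7, 7))
    print(out.shape)
    # torch.Size([2, 512, 1, 1])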

Thanks,
Sourabh

Your approach is also totally fine.
