Flatten layer of a PyTorch model built with a Sequential container

I am trying to build a CNN with PyTorch's Sequential container. My problem is that I cannot figure out how to flatten the output of the convolutional layers.

main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2)) 
main.add_module('flatten', make_it_flatten)

What should I put in place of make_it_flatten? I tried to flatten main itself, but it does not work: main has no method called view.

main = main.view(-1, 16*3*3)


In this case we would prefer to write the model as a class, and keep nn.Sequential for very simple pipelines.
But if you really want to flatten your result inside a Sequential, you could define a module such as

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

and then add it to your model like any other layer, e.g. main.add_module('flatten', Flatten()).
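A minimal sketch of how such a Flatten module slots into a Sequential. The layer sizes mirror the question, and the 3x32x32 input (CIFAR-10-sized images) is an assumption, not something stated in the post; with that input the flattened size works out to 16 * 7 * 7.

```python
import torch
import torch.nn as nn


class Flatten(nn.Module):
    """Reshape (N, C, H, W) -> (N, C*H*W), keeping the batch dimension."""
    def forward(self, input):
        return input.view(input.size(0), -1)


# Assumed input: 3x32x32 images; spatial sizes after each layer in the comments
model = nn.Sequential(
    nn.Conv2d(3, 6, 5, 1, 1),   # 32x32 -> 30x30
    nn.ReLU(),
    nn.MaxPool2d(2, 2),         # 30x30 -> 15x15
    nn.Conv2d(6, 16, 3, 1, 1),  # 15x15 -> 15x15
    nn.ReLU(),
    nn.MaxPool2d(2, 2),         # 15x15 -> 7x7
    Flatten(),                  # (N, 16, 7, 7) -> (N, 784)
    nn.Linear(16 * 7 * 7, 10),
)

x = torch.randn(4, 3, 32, 32)
out = model(x)
print(out.shape)  # torch.Size([4, 10])
```

Using input.size(0) rather than -1 for the first dimension means a wrong in_features surfaces as a size mismatch in the next Linear layer instead of silently changing the batch size.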


Thanks, but this does not work. It gives me this error message:

RuntimeError: size mismatch, m1: [4 x 784], m2: [144 x 120] at /b/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:1237

If I change Flatten to

class Flatten(nn.Module):
    def forward(self, x):
        x = x.view(-1, 16*3*3)
        return x

it gives me this error message:

RuntimeError: size ‘[-1 x 144]’ is invalid for input of with 3136 elements at /b/wheel/pytorch-src/torch/lib/TH/THStorage.c:55
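The numbers in the two errors can be cross-checked with a little arithmetic. This sketch only restates values that appear in the messages above (batch of 4, 784 features per image, a Linear layer expecting 144):

```python
batch_size = 4            # m1 has 4 rows: one per image in the batch
features_per_image = 784  # m1 has 784 columns after flattening
expected = 16 * 3 * 3     # in_features the first nn.Linear was built with (144)

total = batch_size * features_per_image
print(total)  # 3136, the element count in the second error

# view(-1, 144) only works if the element count is a multiple of 144
print(total % expected == 0)  # False, so the view is rejected

# 784 factors as 16 channels * 7 * 7, i.e. each feature map is 7x7, not 3x3
print(features_per_image == 16 * 7 * 7)  # True
```

So the Flatten module itself is fine in the first attempt; the mismatch is between the real feature-map size and the in_features given to the Linear layer.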

Here is the gist of the code:

class Flatten(nn.Module):
    def forward(self, x):
        x = x.view(x.size()[0], -1)
        return x

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        main = nn.Sequential()
        self._conv_block(main, 'conv_0', 3, 6, 5)
        main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
        self._conv_block(main, 'conv_1', 6, 16, 3)        
        main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2))
        main.add_module('flatten', Flatten())
        self._linear_block(main, 'linear_0', 16*3*3, 120)        
        self._linear_block(main, 'linear_1', 120, 84)
        main.add_module('linear_2-84-10.linear', nn.Linear(84, 10))
        self._main = main
    def forward(self, x):
        for module in self._modules.values():
            x = module(x)
        return x
    def _conv_block(self, main, name, inp_filter_size, out_filter_size, kernal_size):        
        main.add_module('{}-{}.{}.conv'.format(name, inp_filter_size, out_filter_size), 
                        nn.Conv2d(inp_filter_size, out_filter_size, kernal_size, 1, 1))
        main.add_module('{}-{}.batchnorm'.format(name, out_filter_size), nn.BatchNorm2d(out_filter_size))
        main.add_module('{}-{}.relu'.format(name, out_filter_size), nn.ReLU())                
    def _linear_block(self, main, name, inp_filter_size, out_filter_size):
        main.add_module('{}-{}.{}.linear'.format(name, inp_filter_size, out_filter_size), 
                        nn.Linear(inp_filter_size, out_filter_size))
        main.add_module('{}-{}'.format(name, out_filter_size), nn.ReLU())

The complete code (90 lines) is on pastebin.

I think I will follow this advice. I picked the Sequential container because I did not know how to group several layers together (conv -> batchnorm -> relu, a residual block, etc.) until I found this example.

The bug is probably in the first linear block.

You have the wrong input size for it: it should probably not be 16*3*3, but something else.
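One way to find the right value is to apply the standard output-size formula for Conv2d/MaxPool2d layers by hand and cross-check it with a dummy forward pass. The sketch assumes 32x32 inputs and copies the kernel/stride/padding settings from the gist; BatchNorm and ReLU are left out because they do not change the shape, and out_size is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def out_size(size, kernel, stride=1, padding=0):
    # Conv2d / MaxPool2d output size (dilation 1): floor((size + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1


# Assumption: 32x32 inputs; layer settings copied from the gist
s = 32
s = out_size(s, kernel=5, stride=1, padding=1)  # conv_0         -> 30
s = out_size(s, kernel=2, stride=2)             # max_pool_0_2_2 -> 15
s = out_size(s, kernel=3, stride=1, padding=1)  # conv_1         -> 15
s = out_size(s, kernel=2, stride=2)             # max_pool_1_2_2 -> 7
in_features = 16 * s * s
print(s, in_features)  # 7 784

# Cross-check by running a dummy batch through the same conv/pool stack
features = nn.Sequential(
    nn.Conv2d(3, 6, 5, 1, 1), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 3, 1, 1), nn.MaxPool2d(2, 2),
)
with torch.no_grad():
    print(features(torch.zeros(1, 3, 32, 32)).shape)  # torch.Size([1, 16, 7, 7])
```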

Also, you are overcomplicating the definition of your model. If your model is just a sequence of layers, you can build the nn.Sequential from an OrderedDict, so that you don't need to create a Net class for it.

Something like

from collections import OrderedDict
import torch.nn as nn

# nn.Sequential rejects module names containing '.', so the names use '_' instead
def _conv_block(l, name, inp_filter_size, out_filter_size, kernel_size):
    # conv -> batchnorm -> relu, each stored under a readable name
    l['{}-{}-{}_conv'.format(name, inp_filter_size, out_filter_size)] = \
        nn.Conv2d(inp_filter_size, out_filter_size, kernel_size, 1, 1)
    l['{}-{}_batchnorm'.format(name, out_filter_size)] = nn.BatchNorm2d(out_filter_size)
    l['{}-{}_relu'.format(name, out_filter_size)] = nn.ReLU()

def _linear_block(l, name, inp_filter_size, out_filter_size):
    # linear -> relu
    l['{}-{}-{}_linear'.format(name, inp_filter_size, out_filter_size)] = \
        nn.Linear(inp_filter_size, out_filter_size)
    l['{}-{}_relu'.format(name, out_filter_size)] = nn.ReLU()

l = OrderedDict()
_conv_block(l, 'conv_0', 3, 6, 5)
l['max_pool_0_2_2'] = nn.MaxPool2d(2, 2)

model = nn.Sequential(l)

And that can still be cleaned up a bit
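One way such a cleanup could look, as a sketch rather than the answerer's own code: each helper returns a list of (name, module) pairs, and the lists are concatenated into the OrderedDict. Assumptions: 32x32 RGB inputs (so the first linear layer takes 16 * 7 * 7 = 784 features), nn.Flatten (available from PyTorch 1.2), and the hypothetical helper names conv_block and linear_block.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def conv_block(name, in_ch, out_ch, kernel_size):
    # conv -> batchnorm -> relu; '_' in names because '.' is not allowed
    return [
        (f'{name}_conv', nn.Conv2d(in_ch, out_ch, kernel_size, 1, 1)),
        (f'{name}_bn', nn.BatchNorm2d(out_ch)),
        (f'{name}_relu', nn.ReLU()),
    ]


def linear_block(name, in_f, out_f):
    # linear -> relu
    return [
        (f'{name}_linear', nn.Linear(in_f, out_f)),
        (f'{name}_relu', nn.ReLU()),
    ]


layers = (
    conv_block('conv_0', 3, 6, 5)
    + [('max_pool_0', nn.MaxPool2d(2, 2))]
    + conv_block('conv_1', 6, 16, 3)
    + [('max_pool_1', nn.MaxPool2d(2, 2)), ('flatten', nn.Flatten())]
    + linear_block('linear_0', 16 * 7 * 7, 120)  # assumes 32x32 inputs
    + linear_block('linear_1', 120, 84)
    + [('linear_2', nn.Linear(84, 10))]
)
model = nn.Sequential(OrderedDict(layers))

out = model(torch.randn(4, 3, 32, 32))
print(out.shape)  # torch.Size([4, 10])
```

Because the layers are named, they stay addressable as attributes afterwards, e.g. model.linear_0_linear or model.flatten.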


You are correct, it should be 16 * 7 * 7 (= 784, which matches m1: [4 x 784] in the first error).

Thanks for the tip about OrderedDict!

Does PyTorch have any plans to provide layers that don't require the input size to be specified?

I don't think there are plans to provide functionality like that, because the size depends on the input, and PyTorch builds its graph dynamically, so there is no separate compilation step in which shapes could be worked out beforehand.
But there are workarounds.
See, for example, the thread "Inferring shape via flatten operator".
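A common workaround, sketched here under the assumption of 3x32x32 inputs, is to push one dummy sample through the convolutional part when building the model and count the flattened features. infer_flat_features is a hypothetical helper name, not a PyTorch API, and this is not necessarily what the linked thread does. Later PyTorch releases (1.8 and up) also added lazy modules such as nn.LazyLinear, which defer in_features until the first forward call; the second half of the sketch shows that alternative.

```python
import torch
import torch.nn as nn


def infer_flat_features(feature_extractor, input_shape):
    """Run one dummy sample through the layers and count the flattened features.

    input_shape excludes the batch dimension, e.g. (3, 32, 32).
    """
    with torch.no_grad():
        out = feature_extractor(torch.zeros(1, *input_shape))
    return out.numel()  # batch size is 1, so numel == features per sample


features = nn.Sequential(
    nn.Conv2d(3, 6, 5, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
)
n = infer_flat_features(features, (3, 32, 32))
model = nn.Sequential(features, nn.Flatten(), nn.Linear(n, 10))
print(n)  # 784

# PyTorch 1.8+: LazyLinear infers in_features on the first forward call
lazy = nn.Sequential(features, nn.Flatten(), nn.LazyLinear(10))
lazy(torch.zeros(2, 3, 32, 32))  # materializes the weight
print(lazy[2].in_features)  # 784
```

The dummy-pass approach works on any version and keeps an ordinary nn.Linear; the lazy version needs one forward call before the parameters exist (e.g. before creating the optimizer).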


Has this been included in the official modules?

It would be useful because it decouples the flattening from the forward method. For example, if I want to convert the FC layers of a pre-trained model to conv layers, and the forward method contains the x.view call, I can't, right?

With the flatten operator inside the model, I can just unplug it and insert a conv layer there. Does this make sense, or am I looking at it from the wrong side?
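To illustrate the FC-to-conv idea: an nn.Linear applied to a flattened (C, H, W) map computes the same thing as a Conv2d whose kernel covers the whole H x W map, so its weights can be copied over with a reshape. This is a minimal sketch with illustrative sizes (16 channels, 7x7 maps, 120 outputs), assuming the flatten uses the default row-major (C, H, W) order of view / nn.Flatten.

```python
import torch
import torch.nn as nn

C, H, W, OUT = 16, 7, 7, 120  # illustrative sizes, not from the thread

fc = nn.Sequential(nn.Flatten(), nn.Linear(C * H * W, OUT))

# A Conv2d with an HxW kernel on an HxW input gives an (N, OUT, 1, 1) output
conv = nn.Conv2d(C, OUT, kernel_size=(H, W))
with torch.no_grad():
    linear = fc[1]
    # Linear weight is (OUT, C*H*W); Flatten ordered the features as (C, H, W)
    conv.weight.copy_(linear.weight.view(OUT, C, H, W))
    conv.bias.copy_(linear.bias)

x = torch.randn(2, C, H, W)
y_fc = fc(x)                 # (2, 120)
y_conv = conv(x).flatten(1)  # (2, 120, 1, 1) -> (2, 120)
print(torch.allclose(y_fc, y_conv, atol=1e-4))  # True
```

Once converted, the conv version also accepts larger inputs and produces a spatial map of scores, which is the usual motivation for this conversion; with the flatten inside the model as a module, the swap is just replacing two entries of the Sequential.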

So the Flatten module is not planned for inclusion in the PyTorch library, then?

Doesn’t this solution work for any shape?

class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)
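For batched inputs of any rank, that version gives the same result as x.view(x.size(0), -1). A quick check against torch.flatten(x, 1) on a few arbitrary shapes, plus one caveat: like any view-based version, it fails on tensors whose memory layout is incompatible with the new shape (e.g. after a permute), whereas torch.flatten falls back to copying.

```python
import torch
import torch.nn as nn


class Flatten(nn.Module):
    # The version from the post: multiply all non-batch dims together
    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)


f = Flatten()
for shape in [(4, 10), (4, 16, 7, 7), (2, 3, 4, 5, 6)]:
    x = torch.randn(*shape)
    assert torch.equal(f(x), torch.flatten(x, 1))
    assert torch.equal(f(x), x.view(x.size(0), -1))

# view() needs compatible (e.g. contiguous) memory; a permuted tensor fails
x = torch.randn(4, 16, 7, 7).permute(0, 2, 3, 1)
try:
    f(x)
except RuntimeError as err:
    print('view failed:', type(err).__name__)
print(torch.flatten(x, 1).shape)  # torch.Size([4, 784])
```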

Hi everyone,

As of torch 1.2.0, a Flatten module is included (nn.Flatten). I attached a small code snippet:

import torch.nn as nn
import torch

f = nn.Flatten()
x = torch.randn((2, 2, 3))
print('Before flatten shape : ', x.shape)
print('After flatten shape : ', f(x).shape)
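For reference, nn.Flatten takes start_dim (default 1) and end_dim (default -1), so the snippet above turns a (2, 2, 3) tensor into (2, 6) and keeps the batch dimension. A short sketch of other settings:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 2, 3)

# Default: start_dim=1, end_dim=-1 keeps the batch dimension
print(nn.Flatten()(x).shape)             # torch.Size([2, 6])

# Flatten everything, including the batch dimension
print(nn.Flatten(start_dim=0)(x).shape)  # torch.Size([12])

# Merge only a sub-range of dims, e.g. (N, C, H, W) -> (N, C, H*W)
y = torch.randn(2, 16, 7, 7)
print(nn.Flatten(start_dim=2)(y).shape)  # torch.Size([2, 16, 49])
```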



Thank you @fmassa. This is great, since I can keep wrapping preprocessing steps (important, but not affecting the gradients) into these classes.

Thank you so much. This has saved me from a torturous rabbit-hole of debugging. I have been struggling for 3 days to fix this shape problem. I appreciate your answer.

Can anyone please explain the view() function and what parameters it takes?
It's very urgent.

Answered here.
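A short illustration of Tensor.view as used throughout this thread: it takes the new shape as integers, at most one of which may be -1 (inferred from the element count), returns a tensor that shares memory with the original, and requires the total number of elements to stay the same. When the memory layout does not allow a view (e.g. after a transpose), reshape() copies instead.

```python
import torch

x = torch.arange(12)       # shape (12,), values 0..11

a = x.view(3, 4)           # same 12 elements, read as 3 rows of 4
b = x.view(-1, 6)          # -1 is inferred: 12 / 6 = 2 -> shape (2, 6)
print(a.shape, b.shape)    # torch.Size([3, 4]) torch.Size([2, 6])

# view shares memory with the original tensor: no copy is made
a[0, 0] = 100
print(x[0].item())         # 100

# The new shape must hold exactly the same number of elements
try:
    x.view(5, -1)          # 12 is not divisible by 5
except RuntimeError:
    print('invalid view')

# view needs compatible (e.g. contiguous) memory; reshape copies if it must
t = torch.randn(3, 4).t()  # transposed -> not contiguous
print(t.is_contiguous())   # False
print(t.reshape(-1).shape) # torch.Size([12])
```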