Add reshape nn.Module

Looking at torchvision.models.resnet34, this is its forward:

class ResNet(nn.Module):

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)

        return x

This means resnet34 could have been defined as

class ResNet(nn.Sequential):

if it were not for the reshape. Then manipulating it would be more straightforward and we would not need to treat it differently. resnet34 is just an example, but in general it would be nice to also have a simple Reshape nn.Module and use it instead of re-implementing forward.
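A minimal sketch of what such a module could look like (the class name `Reshape` and the layer sizes below are illustrative assumptions, not an existing PyTorch API):

```python
import torch
import torch.nn as nn


class Reshape(nn.Module):
    """Reshape incoming tensors to the given target shape, keeping the batch dim."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        # Equivalent to x.reshape(x.size(0), -1) when constructed as Reshape(-1)
        return x.reshape(x.size(0), *self.shape)


# With this module, the tail of a ResNet-like model fits into nn.Sequential
# without a custom forward (channel/class counts here are placeholders):
head = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),
    Reshape(-1),          # replaces the x.reshape(x.size(0), -1) line in forward
    nn.Linear(512, 10),
)
out = head(torch.randn(2, 512, 7, 7))
```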

We have a discussion here, so maybe you would want to explain your use case there, too. :slight_smile:

Thanks. I added to that thread. It is a simple module, but if it were part of torch.nn, the likelihood of simpler implementations would be higher.

Relying on forward is fine if you own the code, or if there are no standard layers you want to manipulate. But if, for example, you want to take the official resnet34 and do something to it, you suddenly realize that some part of the model lives in forward when it could have been a module in __init__. Manipulating a dict is simpler than manipulating code.
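To illustrate the "dict" point: when every step of a model is a submodule of an nn.Sequential, swapping a layer is a key assignment rather than a forward() rewrite. A small hypothetical example (the layer names and sizes are made up for illustration):

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# A purely Sequential model with named submodules:
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(16, 32)),
    ("act", nn.ReLU()),
    ("fc2", nn.Linear(32, 10)),
]))

# Re-heading the model for a 100-class task is a one-line, dict-like edit;
# no subclassing and no forward() re-implementation needed:
model.fc2 = nn.Linear(32, 100)

out = model(torch.randn(4, 16))
```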

I would like to have Reshape as an nn.Module, because then I can use it in nn.Sequential easily.

Would the implementation in the GitHub issue work for you or are you seeing any issues with it?

It will definitely work. Some people might prefer to have it as part of the official PyTorch API, though.