ResNet unexpected output

I’m working on a project and I need the output of a truncated ResNet. Taking ResNet18 as an example, if I make the following changes

import torch
import torch.nn as nn
import torchvision.models as models

resnet = models.resnet18(pretrained=False)
# replace the pooling and classification head with no-ops
resnet.avgpool = nn.Identity()
resnet.fc = nn.Identity()

The output gets flattened and I don’t know why. The curious thing is that if I forward a (16, 3, 256, 256) tensor through the model, I get an output of size (16, 32768), whereas if I manually forward the input layer by layer, I get the expected output of size (16, 512, 8, 8). It seems like a bug, but I’m not sure. Can anyone take a look? The sketch below shows what I mean by forwarding layer by layer.
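For reference, this is roughly the manual, layer-by-layer forward I mean (assuming the standard torchvision ResNet18 attribute names):

x = torch.randn(16, 3, 256, 256)
# run every module up to and including layer4, skipping avgpool/flatten/fc
out = resnet.conv1(x)
out = resnet.bn1(out)
out = resnet.relu(out)
out = resnet.maxpool(out)
out = resnet.layer1(out)
out = resnet.layer2(out)
out = resnet.layer3(out)
out = resnet.layer4(out)
print(out.shape)  # torch.Size([16, 512, 8, 8])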


It’s expected, as ResNet uses a functional call to torch.flatten in its forward method (right after self.avgpool), so replacing avgpool and fc with nn.Identity doesn’t remove the flattening.
You could either derive a custom ResNet class and reimplement the forward method or use forward hooks to get the desired activation.
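A minimal sketch of the forward-hook approach, assuming you want the activation of layer4:

import torch
import torchvision.models as models

resnet = models.resnet18(pretrained=False)

activation = {}

def get_activation(name):
    # store the module's output under the given key
    def hook(module, inp, out):
        activation[name] = out.detach()
    return hook

# grab the feature map right before avgpool/flatten
resnet.layer4.register_forward_hook(get_activation('layer4'))

x = torch.randn(16, 3, 256, 256)
_ = resnet(x)
print(activation['layer4'].shape)  # torch.Size([16, 512, 8, 8])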

Alternatively, you could also create an nn.Sequential container and add the layers in the same order as they are called in the original forward method.
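For example, something like this (a sketch assuming you only need the layers up to and including layer4):

import torch
import torch.nn as nn
import torchvision.models as models

resnet = models.resnet18(pretrained=False)
# children() yields the modules in the same order as the original forward;
# dropping the last two removes avgpool and fc, and no flatten is applied
backbone = nn.Sequential(*list(resnet.children())[:-2])

x = torch.randn(16, 3, 256, 256)
print(backbone(x).shape)  # torch.Size([16, 512, 8, 8])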


That’s true, I should have checked the ResNet code beforehand.
And yes, I did work around it by wrapping the layers in an nn.Sequential container to avoid the flattened tensor.
Thanks ptrblck.
