Problem replacing ResNet layers with nn.Identity in torchvision

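import torch
import torchvision
from torch import nn
from torchsummary import summary

# weights=None builds an untrained model (replaces the deprecated pretrained=False)
model = torchvision.models.resnet18(weights=None)

# Replace the pooling and classifier layers with pass-through layers
model.avgpool = nn.Identity()
model.fc = nn.Identity()

x = torch.randn(size=(2, 3, 128, 128))
y = model(x)
print(x.shape)
print(y.shape)

# Per-layer output shapes (torchsummary defaults to device="cuda")
summary(model, input_size=(3, 128, 128), batch_size=2, device="cpu")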


The output is:
torch.Size([2, 3, 128, 128])
torch.Size([2, 8192])

The output shapes of the last few layers (from torchsummary):
BatchNorm2d-64 [2, 512, 4, 4]
ReLU-65 [2, 512, 4, 4]
BasicBlock-66 [2, 512, 4, 4]
Identity-67 [2, 512, 4, 4]
Identity-68 [2, 8192]

I don’t understand why the last Identity layer flattens the feature map.

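The nn.Identity layer is not flattening the activation; the torch.flatten call in the ResNet forward method (applied through the functional API, between avgpool and fc) is. With avgpool replaced by nn.Identity, that call flattens the [2, 512, 4, 4] output of layer4 into [2, 512 * 4 * 4] = [2, 8192], so the final Identity just passes an already flattened tensor through.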

Thank you for your answer. Can you tell me how to skip this flatten operation?

You could create a custom model and override the forward method with your own, leaving out the torch.flatten call.

OK, I see. Thanks again!