What is nn.Identity() used for?

Skip connections usually add activations instead of concatenating them, as seen in the ResNet example.
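To make the "add" part concrete, here is a rough sketch of a residual block where nn.Identity serves as the shortcut when the input and output shapes already match. The class name `ResidualBlock` and the layer sizes are made up for illustration and are not taken from torchvision's ResNet code:

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Hypothetical residual block, written only to illustrate the idea."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        # If the shapes already match, the shortcut is a no-op;
        # otherwise a 1x1 conv projects the input to the new channel count.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        # The skip connection is added to the block output, not concatenated
        return self.relu(self.body(x) + self.shortcut(x))


block = ResidualBlock(16, 16)
x = torch.randn(2, 16, 8, 8)
out = block(x)
print(out.shape)  # torch.Size([2, 16, 8, 8])
```

Using nn.Identity here keeps `forward` free of an `if` branch: the same line works whether or not a projection is needed.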

A common use case for nn.Identity is to get the “features” of a pretrained model instead of the class logits.
Here is an example:

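import torch
import torch.nn as nn
import torchvision.models as models

# randomly initialized here; pass e.g. weights=models.ResNet18_Weights.DEFAULT
# (torchvision >= 0.13) to load pretrained weights
model = models.resnet18()
# replace last linear layer with nn.Identity
model.fc = nn.Identity()

# get features for input
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)
> torch.Size([1, 512])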
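The output has 512 entries because nn.Identity just hands back whatever it receives, so the pooled activations that would normally go into the linear layer come out untouched. A small check on a standalone tensor (the 4x512 shape is picked arbitrarily for illustration):

```python
import torch
import torch.nn as nn

identity = nn.Identity()
feats = torch.randn(4, 512)  # stand-in for pooled features
same = identity(feats)

# The module returns its input as-is and has nothing to train
print(same is feats)
print(sum(p.numel() for p in identity.parameters()))
```

This should print `True` and `0`, which is why swapping it in for `model.fc` exposes the penultimate features without changing them.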