What is nn.Identity() used for?

Your code snippet shows that nn.Identity just returns its input, but it doesn't show that the output is the very same tensor as the input: no copy is made, so both names refer to the same data.
If you manipulate the input in place, the output of nn.Identity will therefore change as well:

import torch
import torch.nn as nn

a = torch.arange(4.)
m = nn.Identity()
input_identity = m(a)

print(a)
> tensor([0., 1., 2., 3.])

print(input_identity)
> tensor([0., 1., 2., 3.])

# manipulate inplace
a[0] = 2.
print(a)
> tensor([2., 1., 2., 3.])

print(input_identity)
> tensor([2., 1., 2., 3.])
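
If you need the output to be independent of later in-place changes to the input, one option is to clone it. The following is only a minimal sketch of that idea; the variable names (out, independent, same_object, shares_memory) are my own and not from your snippet:

```python
import torch
import torch.nn as nn

a = torch.arange(4.)
m = nn.Identity()
out = m(a)

# nn.Identity hands back the input itself, so both names point at the same tensor
same_object = out is a
shares_memory = out.data_ptr() == a.data_ptr()

# clone() allocates new memory, so this copy is decoupled from `a`
independent = m(a).clone()

# manipulate the input in place
a[0] = 2.

print(same_object, shares_memory)
print(out)          # follows the in-place change to `a`
print(independent)  # keeps the original values
```

The clone costs an extra allocation, so it only makes sense if you really plan to modify the input in place afterwards.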