What is nn.Identity() used for?

I am trying to understand the usage of nn.Identity, but I can't understand what the PyTorch documentation is showing with the following nn.Identity example:

import torch
import torch.nn as nn

m = nn.Identity()
input = torch.randn(4, 4)
output = m(input)
print(output)
print(output.size())

What is nn.Identity doing here?

tensor([[-0.2381, -0.0488,  1.2883, -2.0334],
        [-0.7297, -0.8721,  0.7086,  0.3899],
        [ 0.6550,  1.4832,  0.3744, -1.2825],
        [ 2.3327,  0.5004, -0.8785,  0.5100]])
torch.Size([4, 4])

Without nn.Identity:

input = torch.randn(4, 4)
output = input
print(output)
print(output.size())
tensor([[-0.0255, -0.5771, -0.5268, -0.6201],
        [ 1.0814,  0.4274,  0.6822,  0.2369],
        [ 0.0251, -1.1992, -0.9557,  0.9640],
        [-0.2789, -2.4326, -2.4736,  0.2777]])
torch.Size([4, 4])

m = nn.Identity()
input = torch.randn(4, 4)
output = m(input)
print(f'input {input}')
print(f'output {output}')

input tensor([[-1.3704, -0.4062, -0.5499,  1.8572],
        [ 0.3674,  0.4486, -0.3362, -1.4035],
        [ 0.1322,  0.2116, -0.1157,  1.0715],
        [ 0.8215,  0.1549,  0.8935, -0.1930]])
output tensor([[-1.3704, -0.4062, -0.5499,  1.8572],
        [ 0.3674,  0.4486, -0.3362, -1.4035],
        [ 0.1322,  0.2116, -0.1157,  1.0715],
        [ 0.8215,  0.1549,  0.8935, -0.1930]])
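To compare the snippets above on the same random values, PyTorch's RNG has to be seeded with torch.manual_seed; a plain `seed = 4` assignment only creates a Python variable, which is why each snippet printed different numbers. A minimal sketch under that assumption (the names x1, x2, out_identity and out_plain are just for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(4)               # seed PyTorch's global RNG
x1 = torch.randn(4, 4)
out_identity = nn.Identity()(x1)   # pass the tensor through nn.Identity

torch.manual_seed(4)               # reseed so the next draw repeats x1
x2 = torch.randn(4, 4)
out_plain = x2                     # no module at all

print(torch.equal(out_identity, out_plain))  # True: same values either way
print(out_identity is x1)                    # True: Identity hands back its input object
```

With the RNG seeded, "with nn.Identity" and "without nn.Identity" give identical tensors, which is all the documentation example is meant to show.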

@TheOraware
Please see this accepted answer on Stack Overflow, here.

@pain I read it and understood it, but in a very simple layman's example like the one I posted above, what is it doing with and without nn.Identity? What role is nn.Identity playing in the example above?

@ptrblck has a lot of experience. I hope he can help.

@pain I think I got it: what it does is keep the original input shape intact. Since shapes change across many different layers, we can keep the original input as a placeholder and add it to another layer's output for a skip connection.

a = torch.arange(4.)
print(f'"a" is {a} and its shape is {a.shape}')
m = nn.Identity()
input_identity = m(a)
# change the shape of a
a = torch.reshape(a, (2, 2))
print(f'"a" shape is now {a.shape}')
print(f'due to Identity, input_identity keeps the shape it had at input time: {input_identity.shape}')
"a" is tensor([0., 1., 2., 3.]) and its shape is torch.Size([4])
"a" shape is now torch.Size([2, 2])
due to Identity, input_identity keeps the shape it had at input time: torch.Size([4])

@ptrblck Could you please confirm whether my understanding is correct? If not, please help me understand it.

The nn.Identity module will just return the input without any manipulation and can be used to e.g. replace other layers. The source code can be found here.
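Going by that description, a hand-written equivalent is only a few lines. The sketch below uses a hypothetical class name, MyIdentity, which is not part of PyTorch; it mirrors the documented behavior that nn.Identity accepts and ignores any constructor arguments and returns its input unchanged:

```python
import torch
import torch.nn as nn


class MyIdentity(nn.Module):
    """Hypothetical hand-written pass-through module (not part of PyTorch)."""

    def __init__(self, *args, **kwargs):
        # accept and ignore any arguments, so it can stand in for another layer
        super().__init__()

    def forward(self, x):
        return x  # no copy, no computation


x = torch.randn(2, 3)
print(MyIdentity()(x) is x)                  # True
print(nn.Identity()(x) is x)                 # True
print(nn.Identity(54, unused="ok")(x) is x)  # True: extra arguments are ignored
```

Ignoring the constructor arguments is what makes it convenient as a drop-in replacement: code that builds a layer with some arguments keeps working when that layer is swapped for an identity.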


@ptrblck Could you please look at my last post? I tried to illustrate it in code.

Your code snippet shows that nn.Identity just returns its input, but it doesn't show that the output is the very same tensor as the input (not a copy).
If you manipulate the input in place, the output of nn.Identity will thus also change:

a = torch.arange(4.)
m = nn.Identity()
input_identity = m(a)

print(a)
> tensor([0., 1., 2., 3.])

print(input_identity)
> tensor([0., 1., 2., 3.])

# manipulate inplace
a[0] = 2.
print(a)
> tensor([2., 1., 2., 3.])

print(input_identity)
> tensor([2., 1., 2., 3.])
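This also explains the earlier torch.reshape snippet: `a = torch.reshape(a, (2, 2))` does not change the tensor that nn.Identity returned, it makes the name `a` point to a new tensor object. A short sketch contrasting rebinding a name with modifying the shared data in place (variable names are illustrative):

```python
import torch
import torch.nn as nn

a = torch.arange(4.)
out = nn.Identity()(a)
print(out is a)                        # True: same tensor object, not a copy

# Rebinding: torch.reshape returns a new tensor object and the name `a`
# now points to it; `out` still refers to the original 1-D tensor.
a = torch.reshape(a, (2, 2))
print(a is out, out.shape)             # False torch.Size([4])
print(a.data_ptr() == out.data_ptr())  # True: reshaping a contiguous tensor gives a view

# An in-place write through the reshaped view is visible in `out` as well
a[0, 0] = 10.
print(out)                             # first element is now 10.
```

So the shape of `out` stayed at `[4]` because a different object was reshaped, not because nn.Identity preserved anything; the underlying memory is still shared.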

How can I use nn.Identity() if it just returns the input without any manipulation? What is the point of using it when I can't keep the original input safe from subsequent manipulations? We can achieve that using clone() as shown below, so what is the point of nn.Identity()?

a = torch.arange(4.)
input_identity = a.clone().detach()
print(a)
print(input_identity)
tensor([0., 1., 2., 3.])
tensor([0., 1., 2., 3.])
a[0] = 2.
print(a)
print(input_identity)
tensor([2., 1., 2., 3.])
tensor([0., 1., 2., 3.])

As mentioned before, nn.Identity just returns the input without cloning or otherwise manipulating it. The input and output are thus the same tensor.

Yes, this “pass-through” layer can easily be written manually, which was also the reason feature requests were declined in the past.
However, since a lot of users were rewriting the same layer to e.g. replace specific modules inside a larger model, this layer was introduced.
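A minimal sketch of what "replacing a specific module inside a larger model" can look like. The model below is hypothetical; the only assumption is that the layer to disable (here a dropout layer) sits at a known index of an nn.Sequential:

```python
import torch
import torch.nn as nn

# Hypothetical small model; index 1 is the layer we want to disable
model = nn.Sequential(
    nn.Linear(8, 8),
    nn.Dropout(p=0.5),
    nn.ReLU(),
)

# Swap the dropout for a pass-through without touching the rest of the model
model[1] = nn.Identity()
print(model)

x = torch.randn(3, 8)
y = model(x)
print(y.shape)  # torch.Size([3, 8])
```

Because nn.Identity is a regular nn.Module, it can be assigned wherever a layer is expected, and the model's forward pass does not need to change.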


@ptrblck Thanks for the reply. Sorry, I don't understand your last paragraph:

“Yes, this “pass-through” layer can easily be written manually, which was also the reason feature requests were declined in the past.
However, since a lot of users were rewriting the same layer to e.g. replace specific modules inside a larger model, this layer was introduced.”

Could you please elaborate it more?

You might not want to use this layer as it’s not “doing anything” besides just returning the input.
However, there are use cases where users needed exactly this (e.g. to replace another layer) and were manually creating custom modules to do so and asked for the nn.Identity layer in the PyTorch nn backend. Since more and more users were depending on it, it was created.
However, as already said, this layer might not be interesting for you.
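One pattern where the built-in layer saves a custom module (offered as an illustration of typical usage, not something stated in this thread) is a configurable block whose optional layer is switched off by substituting nn.Identity. The class name Block and the use_bn flag are hypothetical:

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical block whose normalization layer can be switched off."""

    def __init__(self, in_features, out_features, use_bn=True):
        super().__init__()
        self.lin = nn.Linear(in_features, out_features)
        # nn.Identity keeps forward() free of if/else branches
        self.norm = nn.BatchNorm1d(out_features) if use_bn else nn.Identity()

    def forward(self, x):
        return torch.relu(self.norm(self.lin(x)))


x = torch.randn(4, 10)
with_bn = Block(10, 5, use_bn=True)
without_bn = Block(10, 5, use_bn=False)
print(with_bn(x).shape, without_bn(x).shape)  # torch.Size([4, 5]) torch.Size([4, 5])
```

The design choice is that forward() is written once; whether normalization happens is decided only by which module was stored in self.norm.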


@ptrblck Does that mean its best use is connecting a previous layer's output to the output of a new layer? For example, in a residual connection we sometimes connect the output of a previous layer to the output of a new layer. If this is what nn.Identity() is designed for, then we can achieve the same thing as in my code below.

I think I need to read more about nn.Identity(); it may take me some time to understand its underlying mechanism.

However, could you please confirm that the following code has skip connections res and res1, which I use to avoid vanishing gradients? I joined the output of layer 1 (res) and the output of layer 2 (res1) with the output of layer 6 via F.relu(torch.cat([x, res, res1], 1)). Is this the sort of way we join layers, skipping from one to another, to avoid vanishing gradients?

import torch
import torch.nn as nn
import torch.nn.functional as F


class multiNetA(nn.Module):
    def __init__(self, num_feature, num_class):
        super().__init__()

        self.lin1 = nn.Linear(num_feature, 50)
        self.lin2 = nn.Linear(50, 30)
        self.lin6 = nn.Linear(30, 20)
        # 20 (lin6) + 50 (res) + 30 (res1) = 100 concatenated features
        self.lin7 = nn.Linear(100, num_class)

        self.bn0 = nn.BatchNorm1d(num_feature)
        self.bn1 = nn.BatchNorm1d(50)
        self.bn2 = nn.BatchNorm1d(30)
        self.bn6 = nn.BatchNorm1d(20)
        self.bn7 = nn.BatchNorm1d(9)  # defined but not used in forward()

    def forward(self, x):
        x = self.bn0(x)
        x = self.lin1(x)
        x = self.bn1(x)
        x = F.relu(x)
        res = x  # <-- skip connection

        x = self.lin2(x)
        x = self.bn2(x)
        x = F.relu(x)
        res1 = x  # <-- another skip connection

        x = self.lin6(x)
        x = self.bn6(x)
        # join res and res1 (earlier activations) with the output of layer 6
        x = F.relu(torch.cat([x, res, res1], 1))
        x = self.lin7(x)  # output layer

        return x

Skip connections usually add activations instead of concatenating them, as seen in the ResNet example.
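A minimal sketch of an additive skip connection for fully connected layers like the ones in multiNetA. ResidualMLPBlock is a hypothetical name, and using a linear projection as the shortcut when the widths differ is one common choice, not the only one; when the widths match, nn.Identity serves as the shortcut:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualMLPBlock(nn.Module):
    """Hypothetical residual block: output = relu(f(x) + shortcut(x))."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.lin1 = nn.Linear(in_features, out_features)
        self.lin2 = nn.Linear(out_features, out_features)
        # Widths match: pass x through unchanged; otherwise project it
        if in_features == out_features:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Linear(in_features, out_features)

    def forward(self, x):
        out = F.relu(self.lin1(x))
        out = self.lin2(out)
        return F.relu(out + self.shortcut(x))  # add, don't concatenate


x = torch.randn(4, 30)
print(ResidualMLPBlock(30, 30)(x).shape)  # torch.Size([4, 30])
print(ResidualMLPBlock(30, 20)(x).shape)  # torch.Size([4, 20])
```

Addition requires both branches to have the same shape, which is why the shortcut is either an identity or a projection, whereas concatenation (as in multiNetA) grows the feature dimension instead.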

A common use case for nn.Identity is to get the “features” of a pretrained model instead of the class logits.
Here is an example:

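import torch
import torch.nn as nn
from torchvision import models

# random init; e.g. weights=models.ResNet18_Weights.DEFAULT (torchvision >= 0.13) loads pretrained weights
model = models.resnet18()
# replace the last linear layer with nn.Identity
model.fc = nn.Identity()
model.eval()  # inference mode: BatchNorm uses its running statistics

# get features for input
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)
> torch.Size([1, 512])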
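The same trick works for any model whose forward() ends in a layer stored as an attribute. A self-contained sketch that does not need torchvision; TinyClassifier and its layer sizes are made up for illustration:

```python
import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a pretrained backbone with a final `fc` layer."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


model = TinyClassifier()
num_features = model.fc.in_features  # remember the feature size before swapping
model.fc = nn.Identity()             # forward() now returns the features

x = torch.randn(2, 32)
feats = model(x)
print(feats.shape)                   # torch.Size([2, 64])

# A new head can then be trained on top of the extracted features
head = nn.Linear(num_features, 3)
print(head(feats).shape)             # torch.Size([2, 3])
```

Swapping the attribute, rather than rebuilding the model from its children, keeps whatever forward() does before the last layer, such as the flatten call in torchvision's ResNet.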