Detach() in forward pass

Suppose I have a myNetwork class defined as:

import torch
import torch.nn as nn


class myNetwork(nn.Module):
    def __init__(self, a, b, embedding_dim):
        super(myNetwork, self).__init__()
        self.embed1 = nn.Embedding(a, embedding_dim)
        self.embed2 = nn.Embedding(b, embedding_dim)

    def forward(self, idx1, idx2):
        embeds1 = self.embed1(idx1)
        embeds2 = self.embed2(idx2)

        tmp = embeds1 * embeds2      # intermediate result
        tmp1 = embeds1 + tmp         # intermediate result

        # squared difference between tmp1 and embeds2, summed over the embedding dim
        output = torch.sum((tmp1 - embeds2)**2, 1)

        return output

In my training loop, I forward idx1 and idx2 to get output, which is a Variable.

As shown, tmp and tmp1 are intermediate Variables whose gradients I do not need, since I could substitute them away and write the same computation in a single expression:

output = torch.sum(((embeds1 + (embeds1 * embeds2)) - embeds2)**2, 1)

My question is: since tmp and tmp1 do not require gradients, would it be more efficient to detach them in forward(), as below?

....
tmp = tmp.detach()
tmp1 = tmp1.detach()
....

I am new to PyTorch and hope this isn't a bother.

Thanks

Hi,

This is not what detach() is for.
You can see detach() as a break point in the graph: no gradient will flow back past that point.
In your case, if you detach tmp and tmp1, no gradients will be propagated through them back to embeds1 and embeds2.
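Here is a small sketch of that behaviour (the sizes and index values are made up, and it uses the current API where Variables and Tensors are merged):

import torch
import torch.nn as nn

embed1 = nn.Embedding(10, 4)   # made-up sizes
embed2 = nn.Embedding(10, 4)
idx = torch.tensor([1, 2, 3])

# Without detach: gradients reach the embedding weights.
e1, e2 = embed1(idx), embed2(idx)
loss = torch.sum((e1 + e1 * e2 - e2) ** 2, 1).sum()
loss.backward()
print(embed1.weight.grad.abs().sum() > 0)   # tensor(True): embed1 received a gradient

# With tmp and tmp1 detached: the backward path through them is cut,
# so embed1b never receives a gradient from this loss.
embed1b, embed2b = nn.Embedding(10, 4), nn.Embedding(10, 4)
e1, e2 = embed1b(idx), embed2b(idx)
tmp = (e1 * e2).detach()
tmp1 = (e1 + tmp).detach()
loss = torch.sum((tmp1 - e2) ** 2, 1).sum()
loss.backward()
print(embed1b.weight.grad)                  # None: nothing flowed back to embed1b
# embed2b.weight.grad is still populated, because the "- e2" term bypasses the detached tensors.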

In PyTorch, gradients are only computed for the Variables that are created by the user with requires_grad=True. So no gradient will be computed or stored for the temporary Variables by default.
When using nn, keep in mind that creating an nn.Parameter is the same as creating a Variable with requires_grad=True.
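A quick illustration of that default (the shapes here are arbitrary, and the snippet uses the current API where Variable and Tensor are merged):

import torch
import torch.nn as nn

x = torch.randn(3, 4)                 # user-created, requires_grad=False by default
w = nn.Parameter(torch.randn(4, 2))   # same as a leaf created with requires_grad=True

loss = (x @ w).sum()
loss.backward()

print(w.grad.shape)   # torch.Size([4, 2]): a gradient is computed and stored for the parameter
print(x.grad)         # None: no gradient is computed for x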

Thanks for your reply. I understood it.

But when I look at tmp.requires_grad inside forward(), it returns True. So you are saying that even though it is True, no gradient is computed for tmp itself. Am I right?

If a Variable has var.requires_grad=True, that means it was computed from a Variable created with requires_grad=True (a leaf Variable).
And thus, to compute the gradients for all the leaf Variables, some gradients need to flow back through this Variable, even though its own gradient is not stored.
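For example, the following sketch (arbitrary sizes, current tensor API) shows requires_grad being inherited by an intermediate result while its own .grad is never populated:

import torch
import torch.nn as nn

embed = nn.Embedding(10, 4)        # embed.weight is an nn.Parameter, i.e. a leaf with requires_grad=True
idx = torch.tensor([1, 2, 3])

tmp = embed(idx) * 2               # intermediate result

print(tmp.requires_grad)           # True: it descends from a leaf that requires grad
print(tmp.is_leaf)                 # False: it is not a leaf itself

tmp.sum().backward()
print(embed.weight.grad is None)   # False: the leaf parameter received its gradient
print(tmp.grad)                    # None (PyTorch warns): no gradient is stored on the intermediate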

Ok. It is clear now! Thanks a lot.