In my network, a computation applies only to part of a tensor/parameter; for example, squaring the 1st element of parameter `w`.

I used two methods. The forward computations of both methods are correct, but their gradients differ. Method 2 should be correct since it is just a regular computation, so Method 1 must be incorrect. I reckon this is because the computation node is not properly connected to the graph when `torch.tensor()` is used in Method 1?

I want to know why they are different and how the computational graph works when a function applies partially to a tensor.

Method 1

```
# square the 1st element of parameter w with `torch.tensor()`
self.v = torch.tensor((self.w[0]**2, self.w[1]), requires_grad=True)
```

Method 2

```
# square the 1st element of parameter w with a mask
self.mask = torch.tensor([True, False])
self.v = self.w**2 * self.mask + self.w * torch.logical_not(self.mask)
```
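For reference, Method 2 does keep the graph intact. A minimal sketch (using a standalone `w` instead of the module attributes) that checks the gradient:

```python
import torch

w = torch.randn(2, requires_grad=True)
mask = torch.tensor([True, False])
# both branches of the expression stay connected to w in the graph
v = w**2 * mask + w * torch.logical_not(mask)
v.sum().backward()
print(w.grad)  # expected: [2*w[0], 1.]
```

The first entry is the derivative of `w[0]**2`, the second comes from the untouched `w[1]` term.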

Recreating a new tensor detaches the passed tensors from the computation graph and creates a new leaf tensor without a gradient history. Operations on `self.v` will thus not be backpropagated to `self.w`.
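This can be verified directly. A minimal sketch (again with a standalone `w`) showing that the gradient stops at the new leaf:

```python
import torch

w = torch.randn(2, requires_grad=True)
# Method 1: torch.tensor() copies the values into a brand-new leaf tensor
v = torch.tensor((w[0]**2, w[1]), requires_grad=True)
v.sum().backward()
print(w.grad)  # None -> nothing was backpropagated to w
print(v.grad)  # tensor([1., 1.]) -> the gradient stays on the new leaf v
```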

Thanks a lot. That makes sense, and it is why I use Method 2. However, as you can see, Method 2 involves more computation.

It might not differ much in the example mentioned here, but in my model the function is not as simple as squaring, and it applies to a larger proportion of a high-dimensional parameter.

Is there a more efficient way to do this?

You could `torch.stack` the tensor parts as seen here:

```
w = torch.randn(2, requires_grad=True)
# square the 1st element of parameter w with a mask
mask = torch.tensor([True, False])
v = w**2 * mask + w * torch.logical_not(mask)
# same result via stacking only the transformed parts
out = torch.stack((w[0]**2, w[1]))
print((out == v).all())
# tensor(True)
out.mean().backward()
print(w.grad)
# tensor([w[0], 0.5000]) -> d(w[0]**2 / 2)/dw[0] = w[0]
```
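A further option (not from this thread, so treat it as a hypothetical sketch) is to `clone()` the parameter and overwrite only the slice that needs the extra computation. `clone()` is differentiable, and autograd handles the in-place assignment on the copy, which avoids materializing both masked branches:

```python
import torch

w = torch.randn(4, requires_grad=True)
v = w.clone()          # clone() stays in the autograd graph
v[:2] = w[:2]**2       # apply the function only to the first two elements
v.sum().backward()
print(w.grad)  # expected: [2*w[0], 2*w[1], 1., 1.]
```

The overwritten positions receive their gradient through `w[:2]**2`, while the untouched positions receive it through the clone.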

Thanks. Both of your replies answer my questions.

Since the first one relates directly to the title (Incorrect grad …), I will mark it as the Solution.