Diego
May 7, 2019, 8:15am
Hello, I have the following code:
deltac = maxc - minc
s = deltac / maxc
This section does not pass the gradient check, which reports an in-place operation on the second line. I solved it by replacing that line with s = deltac.clone() / maxc. However, I am curious where the in-place operation is here. Thanks in advance for any help!
>>> c = torch.rand(5,10)
>>> c.requires_grad = True
>>> minc, _ = c.min(1)
>>> maxc, _ = c.max(1)
>>> deltac = maxc - minc
>>> s = deltac / maxc
>>> s.sum().backward()
>>> s
tensor([0.9689, 0.8587, 0.8610, 0.9500, 0.9923], grad_fn=<DivBackward0>)
>>> c.grad.max()
tensor(0.1661)
I can't reproduce your problem.
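A minimal sketch of what would trigger the error, assuming the in-place write to deltac happens after the division: for s = deltac / maxc, DivBackward0 keeps deltac to compute the gradient with respect to maxc, so overwriting deltac afterwards makes backward() raise a RuntimeError, while dividing a clone does not. Both function names are hypothetical.

```python
import torch

def saturation_then_inplace(c):
    maxc, _ = c.max(1)
    minc, _ = c.min(1)
    deltac = maxc - minc
    s = deltac / maxc        # DivBackward0 saves deltac for the gradient w.r.t. maxc
    deltac[deltac == 0] = 1  # in-place write bumps deltac's version counter
    return s

def saturation_clone_then_inplace(c):
    maxc, _ = c.max(1)
    minc, _ = c.min(1)
    deltac = maxc - minc
    s = deltac.clone() / maxc  # the division saves the clone, so deltac may change later
    deltac[deltac == 0] = 1
    return s

torch.manual_seed(0)
c = torch.rand(5, 10, requires_grad=True)

try:
    saturation_then_inplace(c).sum().backward()
    inplace_failed = False
except RuntimeError:
    # autograd reports a variable "modified by an inplace operation"
    inplace_failed = True

c.grad = None
saturation_clone_then_inplace(c).sum().backward()
clone_ok = c.grad is not None and bool(torch.isfinite(c.grad).all())
print(inplace_failed, clone_ok)  # True True
```

Note that the error surfaces in the backward of the division, which is why it appears to point at the division line rather than at the assignment.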
Diego
May 7, 2019, 8:59am
Thanks for the swift reply. I will post the entire function that produces the problem; this should provide some more context.
import torch

def rgb_to_hsv(image):
    torch.autograd.set_detect_anomaly(True)
    r = image[..., 0, :, :]
    g = image[..., 1, :, :]
    b = image[..., 2, :, :]

    maxc = image.max(-3)[0]
    minc = image.min(-3)[0]

    v = maxc  # brightness

    deltac = maxc - minc
    s = deltac.clone() / v  # saturation

    deltac[deltac == 0] = 1  # avoid division by zero
    rc = (maxc - r) / deltac
    gc = (maxc - g) / deltac
    bc = (maxc - b) / deltac

    h = 4.0 + gc - rc
    h[g == maxc] = 2.0 + rc[g == maxc] - bc[g == maxc]
    h[r == maxc] = bc[r == maxc] - gc[r == maxc]
    h[minc == maxc] = 0.0

    h = (h / 6.0) % 1.0

    return torch.stack([h, s, v], dim=-3)
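For comparison, a sketch of the same conversion written entirely out of place, assuming the goal is to keep rgb_to_hsv's semantics. Each masked assignment becomes a torch.where, nested so that the red branch takes priority over green and green over the blue default, matching the order of the assignments above. Dividing by a separate safe_delta also keeps the unselected branches finite, since torch.where still backpropagates (zero) gradients through the branch it did not pick. The name rgb_to_hsv_out_of_place is hypothetical and, like the original, it assumes no pixel has v == 0.

```python
import torch
from torch.autograd import gradcheck

def rgb_to_hsv_out_of_place(image):
    # image: (..., 3, H, W); hypothetical out-of-place variant of rgb_to_hsv
    r = image[..., 0, :, :]
    g = image[..., 1, :, :]
    b = image[..., 2, :, :]
    maxc = image.max(-3)[0]
    minc = image.min(-3)[0]
    v = maxc                # brightness
    deltac = maxc - minc
    s = deltac / v          # saturation (assumes v != 0, as in the original)
    # torch.where returns a new tensor, so the deltac saved by the division is untouched
    safe_delta = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
    rc = (maxc - r) / safe_delta
    gc = (maxc - g) / safe_delta
    bc = (maxc - b) / safe_delta
    # Same priority as the masked assignments: r beats g, g beats the b default
    h = torch.where(r == maxc, bc - gc,
                    torch.where(g == maxc, 2.0 + rc - bc, 4.0 + gc - rc))
    h = torch.where(minc == maxc, torch.zeros_like(h), h)
    h = (h / 6.0) % 1.0
    return torch.stack([h, s, v], dim=-3)

# The test image from this thread, in float64 as gradcheck expects
data = torch.tensor([[[[21., 22.], [22., 22.]],
                      [[13., 14.], [14., 14.]],
                      [[8., 8.], [8., 8.]]]],
                    dtype=torch.float64, requires_grad=True)
ok = gradcheck(rgb_to_hsv_out_of_place, (data,), raise_exception=True)
print(ok)  # True
```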
The code for checking the gradients:
from torch.autograd import gradcheck

def test_gradcheck(self):
    data = torch.tensor([[[[21., 22.],
                           [22., 22.]],
                          [[13., 14.],
                           [14., 14.]],
                          [[8., 8.],
                           [8., 8.]]]])  # 1x3x2x2
    data = utils.tensor_to_gradcheck_var(data)  # to var
    assert gradcheck(image.RgbToHsV(), (data,),
                     raise_exception=True)
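Outside that test suite, the same kind of check can be run with plain torch. gradcheck compares analytic gradients against finite differences, so the input should be float64 and require grad, which is presumably what utils.tensor_to_gradcheck_var prepares (an assumption; that helper isn't shown here). The sketch wraps the two lines from the original question in a hypothetical saturation function.

```python
import torch
from torch.autograd import gradcheck

def saturation(c):
    # The two lines from the question, wrapped in a function of c
    maxc, _ = c.max(1)
    minc, _ = c.min(1)
    deltac = maxc - minc
    return deltac / maxc

torch.manual_seed(0)
# float64 input that requires grad, as gradcheck expects
c = torch.rand(5, 10, dtype=torch.float64, requires_grad=True)
ok = gradcheck(saturation, (c,), raise_exception=True)
print(ok)
```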
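The line deltac[deltac == 0] = 1 is an in-place operation. It overwrites deltac, which the division s = deltac / v saved for its backward pass, so the error is reported on the division line. Try to replace it with: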
deltac = torch.where(deltac == 0, 1, deltac)
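The difference can be observed through the tensor's version counter, which autograd compares against the version it recorded when it saved the tensor. _version is an internal attribute, so treat this as a debugging sketch rather than a stable API; ones_like is used instead of a bare 1 so it runs on older releases too.

```python
import torch

x = torch.rand(4, requires_grad=True)
deltac = x * 2.0                  # non-leaf tensor, standing in for maxc - minc

v_before = deltac._version        # freshly created tensor
replaced = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
v_after_where = deltac._version   # unchanged: torch.where returned a new tensor

deltac[deltac == 0] = 1           # masked assignment writes into deltac itself
v_after_assign = deltac._version  # incremented, even if no element was zero

print(v_before, v_after_where, v_after_assign)
```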
Diego
May 7, 2019, 10:01am
Thank you! I had to change your line to
deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
But it works now. Is there a cleaner way to do this, similar to np.where(arr == 0, 1, arr)?
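One possibly cleaner option is Tensor.masked_fill, the out-of-place counterpart of masked_fill_, which takes a mask and a plain Python scalar much like np.where(arr == 0, 1, arr). Recent PyTorch releases also accept a scalar directly in torch.where, so it is worth checking the version in use. A small sketch comparing the two forms and checking that gradients still reach the entries that were not replaced:

```python
import torch

x = torch.tensor([0.0, 0.5, 0.0, 2.0], requires_grad=True)
deltac = x * 1.0  # non-leaf, as in the thread

with_where = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
with_fill = deltac.masked_fill(deltac == 0, 1.0)  # out-of-place: deltac is not modified

with_fill.sum().backward()
print(with_fill)  # values 1.0, 0.5, 1.0, 2.0
print(x.grad)     # 0 where the mask replaced the value, 1 elsewhere
```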