Mean and Sum reduction of the same loss does not produce vectors with angle 0

Hi,

So I have two batches of images of shape (3, 256, 256) and want to calculate the gradient with respect to the image-wise L1 distances. The first batch f_x is actually the output of a larger UNet applied to an input x, and I want to backpropagate through it to get the gradient of the L1 distance with respect to x.

So in code:

# x and f_x have shape (bs, 3, 256, 256)
f_x = f(x)
d = torch.nn.functional.l1_loss(f_x, init_image, reduction='sum')
grad = torch.autograd.grad(d, x, retain_graph=True)[0]

Importantly, my gradient is normalized before I do anything with it (a sketch of that step follows the second snippet below), so I do not care about scale, only about orientation. I noticed some weird behaviour with the gradient, so I tried the following instead:

# x and f_x have shape (bs, 3, 256, 256)
f_x = f(x)
d = torch.nn.functional.l1_loss(f_x, init_image, reduction='mean')
grad = torch.autograd.grad(d, x, retain_graph=True)[0]
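
The normalization I mention above is roughly the following per-sample L2 normalization; this is only a hypothetical sketch (my actual code may differ in details), but it shows that only the direction of the gradient is used afterwards:

# hypothetical sketch: keep only the direction of each sample's gradient
norms = grad.flatten(1).norm(dim=1).clamp_min(1e-12)
grad = grad / norms.view(-1, 1, 1, 1)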

As expected, the value of d in the first case equals the value in the second case times bs * 3 * 256 * 256. So I expected the cosine similarity between the corresponding gradients in the two grad tensors to be 1; after all, differentiation is linear, so the two gradients should only differ by that positive constant factor. I was therefore very surprised when the cosine turned out to be approximately 0.6. As I said, I do not need the scaling information, so working with mean is fine for me, but I would like to understand which of the two computations I can trust more and what causes this discrepancy.
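
To make that expectation concrete, here is a minimal, self-contained check (independent of my UNet; the tensors are made up) that scaling a vector by a positive constant does not change the cosine similarity:

import torch

v = torch.randn(3 * 256 * 256)           # stand-in for one image's gradient
c = 2 * 3 * 256 * 256                     # positive scale factor, e.g. bs * 3 * 256 * 256
cos = torch.nn.functional.cosine_similarity(v, c * v, dim=0)
print(cos)  # tensor(1.), up to floating-point error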

Thanks for your help

It seems I get the expected results using your pseudo-code snippet:

import torch
import torch.nn as nn

bs = 2
x = torch.randn(bs, 3, 256, 256, requires_grad=True)
f_x = x * 2
init_image = torch.randn_like(x)

d = torch.nn.functional.l1_loss(f_x, init_image, reduction='sum')
grad1 = torch.autograd.grad(d, x, retain_graph=True)[0]


d = torch.nn.functional.l1_loss(f_x, init_image, reduction='mean')
grad2 = torch.autograd.grad(d, x, retain_graph=True)[0]

print(((grad1 / x.nelement()) - grad2).abs().max())
# tensor(0.)

out = nn.CosineSimilarity()(grad1, grad2)
torch.allclose(out, torch.ones_like(out))
# True
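
As a side note, nn.CosineSimilarity uses dim=1 by default, so the call above compares the gradients along the channel dimension and returns one value per spatial location. If you want a single cosine per image (matching the image-wise distances mentioned above), one way would be to flatten each sample first:

# one cosine value per image: flatten each sample, then compare along dim=1
cos_per_image = nn.CosineSimilarity(dim=1)(grad1.flatten(1), grad2.flatten(1))
print(cos_per_image)  # tensor([1., 1.]), up to floating-point error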