How does the gradient flow through torch.where() and tensor.min()?


  1. When I call e1.backward(), will the gradient flow properly back to the parameters of the network?
  2. How does this torch.where() statement differ from d1[c2 <= 0] = 0.0, or from if-else statements (if they can be used here)?
  3. Could you suggest some good implementations of mask and threshold operations that allow the gradient to flow through them?

Please see the attached image for the computation flow (roughly).

d1 is c1 modified according to the condition (mask) computed from c2.

a: tensor of shape [16, 3, 256, 256] (a batch of RGB images)
c1, c2: single-channel tensors of shape [16, 1, 256, 256]

The chain a-b2-c2 includes operations like torch.round(), modulo, view, squeeze, stack, min, clone, etc.

I want to update the parameters of the ‘Network’ using the gradient computed from e1, given that both c1 and c2 depend on a (the output of the network).


c2 = b2.min(dim=0, keepdim=True)[0]  # min() with dim returns (values, indices); [0] takes the values
d1 = torch.where(c2 <= 0, torch.zeros_like(c1), c1)  # zero out c1 wherever c2 <= 0; the constant needs no requires_grad
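Since the title also asks about tensor.min(), here is a small sketch of how gradient passes through min(dim=...) on its own. It assumes b2 is a stack of K candidate maps along dim 0 (a guess based on the stack/min ops mentioned above); the tiny shapes and random values are made up for illustration. With a dim argument, the gradient reaches only the element that min() selected, so it behaves like indexing with the returned indices.

```python
import torch

# Hypothetical stand-in for b2: K=3 candidate maps stacked along dim 0,
# each of shape [2, 1, 2, 2] (a tiny version of [16, 1, 256, 256]).
torch.manual_seed(0)
b2 = torch.randn(3, 2, 1, 2, 2, requires_grad=True)

# Same call as in the question; .values is equivalent to [0]
c2 = b2.min(dim=0, keepdim=True).values   # shape [1, 2, 1, 2, 2]

# Backpropagate a gradient of ones from c2
c2.sum().backward()

# Each position receives gradient 1 only at the index that min() selected
idx = b2.min(dim=0).indices               # shape [2, 1, 2, 2]
expected = torch.zeros_like(b2).scatter_(0, idx.unsqueeze(0), 1.0)
print(torch.equal(b2.grad, expected))     # True
```

In the question's graph, though, c2 is only used inside the comparison c2 <= 0, so this path never receives any gradient from e1 in the first place.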


Autograd computes the gradient of each op.
A good mental model is: how much does the output change if you perturb the input by a small epsilon?
Here, if c2 moves by eps, nothing changes (except very close to 0, which is a special case). So the gradient flowing back along c2 is 0; in fact, c2 <= 0 produces a boolean tensor, which autograd does not track at all.
For c1, the values that were overwritten do not affect the result either: if you change them, d1 stays the same. So their gradient is 0.
For the values that were kept as-is, the gradient is exactly the one flowing back from e1 into d1.
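A minimal sketch of that reasoning on tiny tensors (all values, and the 2-row b2, are made up for illustration; e1 is an arbitrary scalar loss standing in for the real one). It checks the three claims: the mask carries no gradient, overwritten entries of c1 get 0, and kept entries get the upstream gradient unchanged.

```python
import torch

# Tiny stand-ins (hypothetical values) for c1 and for the b2 that produces c2
c1 = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
b2 = torch.tensor([[0.5, -1.0, 2.0, -0.3],
                   [0.7,  0.2, 1.5,  0.1]], requires_grad=True)

c2 = b2.min(dim=0).values          # [0.5, -1.0, 1.5, -0.3]
mask = c2 <= 0                     # boolean tensor: not tracked by autograd
d1 = torch.where(mask, torch.zeros_like(c1), c1)
e1 = (d1 * 10).sum()               # any scalar loss; de1/dd1 = 10 everywhere

e1.backward()

print(mask.requires_grad)  # False
print(c1.grad)             # values 10, 0, 10, 0
print(b2.grad)             # None: b2 only reaches e1 through the condition
```

So the network parameters still receive gradient from e1 through c1; whatever reaches e1 only through c2 gets none from this op.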

Indexing, masked assignment, torch.where, etc. all follow this logic and will give you the same gradient.
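A quick way to check that claim is to compute the gradient of the same loss through several masking styles and compare them. This is a sketch with hypothetical helper names (grad_of, via_where, etc.), made-up values, and an arbitrary loss (d1 ** 2).sum() standing in for e1; the fixed boolean mask plays the role of c2 <= 0. The masked-assignment variant writes into a clone so that the leaf c1 itself is not modified in place.

```python
import torch

def grad_of(fn):
    """Return d(loss)/d(c1) for a masking function fn(c1, mask) -> d1."""
    c1 = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
    mask = torch.tensor([False, True, False, True])  # stands in for c2 <= 0
    d1 = fn(c1, mask)
    (d1 ** 2).sum().backward()   # arbitrary scalar loss standing in for e1
    return c1.grad

def via_where(c1, mask):
    return torch.where(mask, torch.zeros_like(c1), c1)

def via_index_assign(c1, mask):
    d1 = c1.clone()      # clone so the in-place write does not touch the leaf c1
    d1[mask] = 0.0       # same form as d1[c2 <= 0] = 0.0 in the question
    return d1

def via_masked_fill(c1, mask):
    return c1.masked_fill(mask, 0.0)

def via_multiply(c1, mask):
    return c1 * (~mask).to(c1.dtype)

grads = [grad_of(f) for f in (via_where, via_index_assign, via_masked_fill, via_multiply)]
print(all(torch.equal(grads[0], g) for g in grads))  # True
print(grads[0])  # 2 * c1 where kept, 0 where masked: values 2, 0, 6, 0
```

Which style to use is then mostly a matter of readability; a Python if-else, by contrast, can only branch on a single boolean, not elementwise over a tensor.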
