I am writing perturbation code in PyTorch, and I want to generate a mask with the same shape as the input:

```python
import torch

mask = model1(mask_initial)  # mask tensor with the same shape as the input
threshold = torch.max(mask) * ratio

# perturb the input: keep the original entry where mask > threshold, otherwise use the perturbation
after_input = dict(initial_input)  # start from a copy of the original input
# initial_input['id'] and perturbation_input['id'] are torch.int32
after_input['id'] = torch.where(mask > threshold, initial_input['id'], perturbation_input['id'])
# initial_input['value'] and perturbation_input['value'] are torch.float
after_input['value'] = torch.where(mask > threshold, initial_input['value'], perturbation_input['value'])

# compute the loss
score1 = model2(initial_input)
score2 = model2(after_input)
loss = score1 - score2
```

I expect this loss to optimize the parameters of model1. However, mask.grad is None, and so are the gradients of model1's parameters. How can I make this work? Can torch.where propagate gradients to its condition?

torch.where() does not propagate gradients to its condition.
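A quick check makes this concrete. In the sketch below, mask, a and b are hypothetical stand-ins for model1's output and the two inputs. It shows that the comparison produces a boolean tensor outside the autograd graph, so the result of torch.where has no path back to mask:

```python
import torch

# Hypothetical stand-ins: mask plays model1's output, a and b the two inputs
mask = torch.rand(4, requires_grad=True)
a = torch.ones(4)   # plays the role of initial_input['value']
b = torch.zeros(4)  # plays the role of perturbation_input['value']

cond = mask > 0.5              # comparison returns a bool tensor that autograd does not track
out = torch.where(cond, a, b)  # the only link to mask is through cond

print(cond.dtype, cond.requires_grad)  # torch.bool False
print(out.requires_grad)               # False: nothing connects out back to mask
```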

The reason is that those gradients, although well defined, aren't useful.
Thresholding mask is a step function: when mask < threshold, the gradient
is zero; when mask > threshold, the gradient is also zero; and when mask is
exactly equal to threshold, the gradient is (usually) undefined. So the
gradient is almost always zero, and there is no point in backpropagating it.
In practice, mask > threshold returns a boolean tensor, which autograd does
not track at all.
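Written out per element for the float 'value' field, the chain rule shows why. In this sketch, m is mask, x is initial_input['value'], x̃ is perturbation_input['value'], a is after_input['value'], H is the Heaviside step function, and the threshold τ is treated as a constant (an assumption; in the code above it also depends on mask through torch.max):

$$
a_i = H(m_i - \tau)\,x_i + \bigl(1 - H(m_i - \tau)\bigr)\,\tilde{x}_i
$$

$$
\frac{\partial \mathcal{L}}{\partial m_i}
= \frac{\partial \mathcal{L}}{\partial a_i}\,(x_i - \tilde{x}_i)\,H'(m_i - \tau),
\qquad H'(z) = 0 \ \text{for}\ z \neq 0
$$

Whatever model2 contributes through the first factor, the last factor is zero for every element not sitting exactly on the threshold.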

If you can construct a function() that gives you a useful surrogate for
(or approximation to) the gradient of loss with respect to mask, then
backpropagating that surrogate could work.

Note, however, that if your surrogate gradient is zero almost everywhere,
it won't do you any good (even though it might be a very good
approximation to the actual gradient).
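One way to build a nonzero surrogate directly into autograd is a straight-through-style estimator. This is a swapped-in technique, not one described above: a custom torch.autograd.Function keeps the exact hard threshold in the forward pass but returns the derivative of a sigmoid in the backward pass. The class name HardThresholdSurrogate, the temperature 0.1 and the ratio 0.5 are hypothetical choices for illustration:

```python
import torch


class HardThresholdSurrogate(torch.autograd.Function):
    """Hypothetical helper: hard 0/1 threshold forward, sigmoid-derivative surrogate backward."""

    @staticmethod
    def forward(ctx, mask, threshold, temperature):
        ctx.save_for_backward(mask, threshold)
        ctx.temperature = temperature
        # Exact hard threshold, same values as mask > threshold
        return (mask > threshold).to(mask.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        mask, threshold = ctx.saved_tensors
        s = torch.sigmoid((mask - threshold) / ctx.temperature)
        # Surrogate: derivative of sigmoid((mask - threshold) / temperature) w.r.t. mask
        surrogate = s * (1 - s) / ctx.temperature
        # No gradients for threshold or temperature
        return grad_output * surrogate, None, None


mask = torch.tensor([0.2, 0.6, 0.9], requires_grad=True)
threshold = mask.max().detach() * 0.5  # ratio = 0.5 (assumption)
x = torch.tensor([1.0, 2.0, 3.0])      # stand-in for initial_input['value']
x_pert = torch.zeros(3)                # stand-in for perturbation_input['value']

hard = HardThresholdSurrogate.apply(mask, threshold, 0.1)  # temperature 0.1 (assumption)
after_value = hard * x + (1 - hard) * x_pert  # same values as torch.where(mask > threshold, x, x_pert)
after_value.sum().backward()
print(hard.detach())  # tensor([0., 1., 1.])
print(mask.grad)      # nonzero entries: the surrogate reaches mask
```

Because the forward values match the torch.where version, model2 still sees exactly the hard-perturbed input; only the backward pass changes. This blending only applies to the float 'value' field, since integer tensors such as the int32 'id' field cannot carry gradients.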

If you use this approach, don't overwrite mask.grad. Instead, call
mask.backward(function(loss)). You can pass to .backward() the gradient
(or a surrogate for it) of your downstream scalar loss with respect to
mask (or whatever tensor you're calling .backward() on), and autograd will
complete the backpropagation through model1 for you.
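A minimal sketch of this pattern follows, assuming tiny nn.Linear stand-ins for model1 and model2 and a single float field in place of the input dict. The hypothetical surrogate applies the chain rule from the hard-thresholded loss but replaces the step's zero derivative with a sigmoid's; ratio 0.5 and temperature 0.1 are arbitrary:

```python
import torch

torch.manual_seed(0)

model1 = torch.nn.Linear(4, 4)  # stand-in for model1 (assumption)
model2 = torch.nn.Linear(4, 1)  # stand-in for model2 (assumption)

mask_initial = torch.randn(4)
x = torch.randn(4)              # stand-in for initial_input['value']
x_pert = torch.zeros(4)         # stand-in for perturbation_input['value']

mask = model1(mask_initial)
threshold = mask.detach().max() * 0.5  # ratio = 0.5 (assumption)

# Hard-perturbed input as a fresh leaf, so autograd can give d(loss)/d(after)
after = torch.where(mask.detach() > threshold, x, x_pert).requires_grad_()
loss = (model2(x) - model2(after)).sum()  # mirrors score1 - score2
grad_after, = torch.autograd.grad(loss, after)

# Hypothetical surrogate d(loss)/d(mask): step derivative replaced by a sigmoid's
T = 0.1
s = torch.sigmoid((mask.detach() - threshold) / T)
surrogate = grad_after * (x - x_pert) * s * (1 - s) / T

# Hand the surrogate to autograd; it backpropagates through model1
mask.backward(surrogate)
print(model1.weight.grad is not None)  # True
```

The gradient argument must have the same shape as mask. Here it is built from detached tensors, so it does not itself require grad.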

But this will only work if your surrogate gradient, function(loss), is
actually useful and approximates the gradient of some sort of “smoothed
thresholding” version of the not-usefully-differentiable thresholding you
apply to mask.
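As one possible “smoothed thresholding” (an assumption, not the only choice), the step can be replaced by a sigmoid with a temperature, and the two float inputs blended with the resulting soft weights. The helper soft_threshold, the temperature 0.1 and the ratio 0.5 are hypothetical. Unlike the hard version, this keeps a nonzero gradient to mask, including through torch.max in the threshold:

```python
import torch


def soft_threshold(mask, threshold, temperature=0.1):
    # Hypothetical smoothed stand-in for (mask > threshold);
    # approaches the hard step as temperature -> 0
    return torch.sigmoid((mask - threshold) / temperature)


mask = torch.tensor([0.2, 0.6, 0.9], requires_grad=True)
threshold = mask.max() * 0.5        # ratio = 0.5 (assumption); gradient also flows through max
x = torch.tensor([1.0, 2.0, 3.0])   # stand-in for initial_input['value']
x_pert = torch.zeros(3)             # stand-in for perturbation_input['value']

w = soft_threshold(mask, threshold)
after_value = w * x + (1 - w) * x_pert  # soft analogue of torch.where(mask > threshold, x, x_pert)
after_value.sum().backward()
print(mask.grad)  # nonzero entries: the gradient now reaches mask
```

A smaller temperature makes the perturbed input closer to the hard version but concentrates the gradient in a narrower band around the threshold, so very small temperatures bring back the near-zero-gradient problem.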