I have this function to block the gradient based on a mask:
import torch

class blocked_grad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mask):
        # save the mask for use in backward
        ctx.save_for_backward(x, mask)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        x, mask = ctx.saved_tensors
        # gradient flows only where mask == 1; the mask itself gets no
        # meaningful gradient (returning None for it would also work)
        return grad_output * mask, mask * 0.0
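For reference, this is how I call it (a minimal usage sketch; the shapes and values are just an example):

x = torch.randn(3, requires_grad=True)
mask = torch.tensor([1.0, 0.0, 1.0])
out = blocked_grad.apply(x, mask)
out.sum().backward()
print(x.grad)  # tensor([1., 0., 1.]) -- gradient is zeroed where mask == 0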
The mask values are 0 or 1, so this blocks the gradient completely wherever the mask is 0. What I'm wondering is how I could partially mask the gradient instead, i.e. attenuate it by some scaling parameter rather than zeroing it out.
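To make the question concrete, here is roughly what I have in mind; the name partially_blocked_grad and the scaling parameter alpha are just placeholders, and I am not sure this is the right approach:

class partially_blocked_grad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, mask, alpha):
        # alpha is a plain float in [0, 1]; alpha = 0 reproduces hard blocking
        ctx.save_for_backward(mask)
        ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # interpolate: pass the full gradient where mask == 1,
        # only alpha * gradient where mask == 0
        scale = mask + ctx.alpha * (1.0 - mask)
        # one gradient per forward input: x, mask, alpha
        return grad_output * scale, None, None

Is something like this the right way to do it, or is there a more standard pattern?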