I am trying to fix (freeze at zero) some specific elements via the grad_output tensor, as below:
import torch
from torch.autograd.function import InplaceFunction

_IDX_LIST = [2, 4, 6, 9]

class GradZero(InplaceFunction):
    @staticmethod
    def forward(ctx, input):
        ctx.inplace = False
        return input

    @staticmethod
    def backward(ctx, grad_output):
        with torch.no_grad():
            grad_input = grad_output.detach().clone()
            grad_input = fix_into_zero(grad_input)
        # forward takes a single input, so backward returns a single gradient
        return grad_input

def fix_into_zero(x, idx_list=_IDX_LIST):
    x[idx_list] = 0.0
    x[idx_list].requires_grad = False  # my attempt to freeze those elements
    return x
And then I found, after running

for param in test_model.parameters():
    print("{}'s requires_grad is {}".format(param.shape, param.requires_grad))

that all the params stay as requires_grad = True. Maybe that is because an individual element has no requires_grad attribute of its own?
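A quick standalone check seems to confirm that suspicion (the tensor t below is just an illustrative stand-in, not part of my model):

import torch

t = torch.zeros(10, requires_grad=True)

# Fancy indexing produces a new (non-leaf) tensor, so flipping a flag on
# the result cannot affect the original tensor's elements.
selected = t[[2, 4, 6, 9]]
print(selected is t)       # False: a different tensor object
print(t.requires_grad)     # True: the original flag is untouched

# requires_grad is a per-tensor flag; there is no per-element version.
# Setting selected.requires_grad = False would even raise a RuntimeError,
# since the flag can only be changed on leaf tensors.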
In that case, how can I turn off autograd for some specific elements, as in the snippet above? I also referred to this post (How can I disable all layers gradient expect the last layer in Pytorch?) but could not figure it out.
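What I effectively want, I think, is equivalent to something like the gradient hook below (test_param and zero_selected_grads are just illustrative names for a single 1-D parameter; my real model has several), but I am not sure whether this is the idiomatic way:

import torch

_IDX_LIST = [2, 4, 6, 9]

# Hypothetical 1-D parameter standing in for a real one in test_model.
test_param = torch.nn.Parameter(torch.randn(10))

def zero_selected_grads(grad):
    # Hooks receive the gradient and may return a modified copy;
    # zeroing entries here means those elements never get updated.
    grad = grad.clone()
    grad[_IDX_LIST] = 0.0
    return grad

test_param.register_hook(zero_selected_grads)

loss = (test_param ** 2).sum()
loss.backward()
print(test_param.grad)  # entries 2, 4, 6, 9 are zero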