Masking without breaking the computation graph


Say I have a tensor a: tensor([0.6511, 0.6983, 0.2852, 0.2281], requires_grad=True). I convert each value to 1 if it is greater than 0.5 and to 0 otherwise, so the new tensor is tensor([1., 1., 0., 0.]). I use this tensor to perform a masking operation. However, when I backpropagate, the gradients of a become 0.

Another approach might be to use a.detach(). However, then a is no longer part of the computation graph, and no gradients are computed for it.

How can I do this without breaking the graph and without the gradients becoming 0?

Small snippet:

a = regularizer(torch.mul(per_step_task_embedding, layer_out))  # tensor([0.6511, 0.6983, 0.2852, 0.2281], requires_grad=True)
b = a.round()  # tensor([1., 1., 0., 0.]); round() is piecewise constant, so its gradient w.r.t. a is 0
for i, (p, g) in enumerate(zip(params, grad)):
    p.update = -g * b[i]  # masking: keep the update only where b[i] == 1

Thank you!
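A minimal reproduction of both behaviors described in the question, assuming only PyTorch and the example values above (w is a hypothetical downstream parameter added so that backward() has something to differentiate). torch.round() is piecewise constant and PyTorch defines its gradient as zero, so the graph stays connected but every gradient reaching a is 0; detach() instead removes a from the graph, so a.grad is never populated.

```python
import torch

# Values from the question; requires_grad=True makes `a` a leaf whose .grad we can inspect.
a = torch.tensor([0.6511, 0.6983, 0.2852, 0.2281], requires_grad=True)

# 1) round(): `b` is still connected to `a`, but round() has zero gradient.
b = a.round()                     # 0/1 values, stored as floats
b.sum().backward()
grad_after_round = a.grad.clone()
print(grad_after_round)           # tensor([0., 0., 0., 0.])

# 2) detach(): the rounded copy is cut out of the graph entirely.
a.grad = None                     # reset before the second experiment
w = torch.ones(4, requires_grad=True)  # hypothetical downstream parameter
c = a.detach().round()
print(c.requires_grad)            # False
(w * c).sum().backward()
print(a.grad)                     # None: no gradient ever reaches `a`
```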

OK, I found a workaround (a bit of a hack) that makes this work. I'm writing down what worked for me in case anyone comes across a similar problem.
First, the regularizer I used had a Sigmoid() at the end. I removed it.
So now I have:

layer_pred = regularizer(torch.mul(per_step_task_embedding, layer_out))  # regularizer no longer ends in Sigmoid()
sigm = torch.nn.Sigmoid()
temp_pred = layer_pred.detach()  # keep the original values
cond = sigm(temp_pred).round()  # [0 or 1]
layer_pred.data = cond  # replace data in layer_pred without changing the graph
masking_module(layer_pred, params, gradients)
layer_pred.data = temp_pred  # replace with the original data again, without changing the graph
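The workaround above seems to aim for what is usually called a straight-through estimator: the forward pass uses the rounded 0/1 values, while the backward pass behaves as if the rounding were the identity. As an alternative sketch (not the poster's code), the same effect can be obtained without touching .data through the common soft + (hard - soft).detach() trick. Here logits stands in for layer_pred, and grads and the summed loss are hypothetical stand-ins for the masking loop in the question.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: `logits` plays the role of layer_pred (regularizer output
# without its final Sigmoid); `grads` mimics the per-parameter gradients in the loop.
logits = torch.randn(4, requires_grad=True)
grads = [torch.ones(3) for _ in range(4)]

soft = torch.sigmoid(logits)        # differentiable scores in (0, 1)
hard = soft.round()                 # 0/1 values; zero gradient on its own
# Straight-through estimator: the forward value equals `hard`, but since the
# correction term is detached, d(mask)/d(soft) is 1 in the backward pass.
mask = soft + (hard - soft).detach()

# Masking as in the question: keep an update only where mask[i] == 1.
updates = [-g * mask[i] for i, g in enumerate(grads)]

# A toy loss built from the masked updates; its gradient reaches `logits`.
loss = torch.stack(updates).sum()
loss.backward()
print(mask)
print(logits.grad)                  # non-zero, unlike with round() alone
```

Because the detached term carries no gradient, mask holds the 0/1 values in the forward pass but has gradient 1 with respect to soft, so in this toy setup logits.grad equals -3 * s * (1 - s) with s = sigmoid(logits), rather than zero. Note that the estimator is biased: it simply ignores the rounding when computing gradients.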

Just replying to let you know that you shouldn’t use the .data field: it’s only there for backwards compatibility and can lead to incorrect gradients. You can read more here. Try using .copy_() instead.
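To illustrate how .data can lead to incorrect results, here is a sketch of the commonly cited sigmoid example: in-place edits made through .data are invisible to autograd's version tracking, whereas the same edit made through .detach() shares the version counter and is reported as an error. The inputs 1.0 and 2.0 are arbitrary.

```python
import torch

# Editing through .data: autograd's version check is bypassed, so backward
# silently uses the modified values.
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x.sigmoid()          # sigmoid saves its output y for the backward pass
y.data.zero_()           # overwrite y's storage behind autograd's back
y.sum().backward()
print(x.grad)            # tensor([0., 0.]), although sigmoid'(x) is non-zero

# Editing through .detach(): storage *and* version counter are shared,
# so autograd notices the in-place change and refuses to backpropagate.
x2 = torch.tensor([1.0, 2.0], requires_grad=True)
y2 = x2.sigmoid()
y2.detach().zero_()
raised = False
try:
    y2.sum().backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```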

Thanks a lot, @AlphaBetaGamma96!
Here is the version with .data replaced by .copy_():

temp_pred = layer_pred.clone()
cond = sigm(temp_pred).round() # [0 or 1] 
masking_module(layer_pred, params, gradients)