Boolean masks and in-place operations with autograd

I am using a boolean mask in a network that performs some attention mechanism. The mask is updated in-place inside a loop, which causes autograd to crash. Here is a minimal example of the kind of thing I do that does NOT work:

import torch
import torch.nn.functional as F 

d = 4
x = torch.rand(d, requires_grad=True)
mask = torch.zeros(d).bool()

# iteration 1
label = 1
mask[0] = True
y = x.masked_fill( mask , float('-inf') )
p = F.softmax(y,dim=0)
loss = - torch.log( p[label] )

# compute by hand grad of the loss at iteration 1
indicator = torch.zeros(d)
indicator[label] = 1
gradloss1 = p-indicator

# iteration 2
label = 1
mask[3] = True
y = x.masked_fill( mask , float('-inf') )
p = F.softmax(y,dim=0)
loss += - torch.log( p[label] )

# compute by hand grad of the loss at iteration 2
indicator = torch.zeros(d)
indicator[label] = 1
gradloss2 = p-indicator

# backprop
loss.backward()

# check we got correct gradient
print(x.grad)
print(gradloss1+gradloss2)

This can be fixed by replacing the code above with the following, in which the boolean mask is not updated in-place but rebuilt at every iteration:

import torch
import torch.nn.functional as F 

d = 4
x = torch.rand(d, requires_grad=True)

# iteration 1
label = 1
mask = torch.zeros(d).bool()
mask[0] = True
y = x.masked_fill( mask , float('-inf') )
p = F.softmax(y,dim=0)
loss = - torch.log( p[label] )

# compute by hand grad of the loss at iteration 1
indicator = torch.zeros(d)
indicator[label] = 1
gradloss1 = p-indicator

# iteration 2
label = 1
mask = torch.zeros(d).bool()
mask[0] = True
mask[3] = True
y = x.masked_fill( mask , float('-inf') )
p = F.softmax(y,dim=0)
loss += - torch.log( p[label] )

# compute by hand grad of the loss at iteration 2
indicator = torch.zeros(d)
indicator[label] = 1
gradloss2 = p-indicator

# backprop
loss.backward()

# check we got correct gradient
print(x.grad)
print(gradloss1+gradloss2)

I am a little confused because the boolean mask always has its requires_grad flag set to False. Why does it matter whether the mask is updated in-place or not? Can someone explain in a little more detail how autograd works in this case? Thanks!

The thing is that the backward pass of masked_fill() needs to know which positions were filled, so the mask is saved for the backward computation even though it does not require gradients. If you then modify the mask in-place, the saved tensor no longer matches what was used in the forward pass; the autograd engine detects this through the tensor's version counter and raises an error instead of silently computing wrong gradients.
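
For completeness, here is a small sketch (not from the original posts, just an illustration) of a middle-ground workaround: you can keep updating a single mask in-place, as long as masked_fill() receives a clone, so the copy that autograd saves for the backward pass is never mutated:

import torch
import torch.nn.functional as F

d = 4
x = torch.rand(d, requires_grad=True)
mask = torch.zeros(d).bool()
label = 1
loss = torch.zeros(())

for idx in [0, 3]:            # positions masked at iterations 1 and 2
    mask[idx] = True          # the same mask is still updated in-place...
    # ...but masked_fill() gets a clone, so the tensor saved for backward
    # is never modified afterwards
    y = x.masked_fill(mask.clone(), float('-inf'))
    p = F.softmax(y, dim=0)
    loss = loss - torch.log(p[label])

loss.backward()               # no in-place error
print(x.grad)

Whether you clone the mask or rebuild it at every iteration is mostly a matter of taste; the important point is that the tensor actually passed to masked_fill() must stay untouched until backward() has run.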

Thanks, that makes complete sense!