I am using a boolean mask in a network that performs some attention mechanisms. The mask is updated in-place inside a loop, which causes autograd to raise an error when I call backward. Here is a minimal example of the kind of thing I do that does NOT work:

```
import torch
import torch.nn.functional as F
d = 4
x = torch.rand(d, requires_grad=True)
mask = torch.zeros(d).bool()
# iteration 1
label = 1
mask[0] = True  # in-place update of the mask
y = x.masked_fill(mask, float('-inf'))
p = F.softmax(y, dim=0)
loss = -torch.log(p[label])
# compute by hand the grad of the loss at iteration 1
indicator = torch.zeros(d)
indicator[label] = 1
gradloss1 = p - indicator
# iteration 2
label = 1
mask[3] = True  # second in-place update of the same mask tensor
y = x.masked_fill(mask, float('-inf'))
p = F.softmax(y, dim=0)
loss += -torch.log(p[label])
# compute by hand the grad of the loss at iteration 2
indicator = torch.zeros(d)
indicator[label] = 1
gradloss2 = p - indicator
# backprop: this is where autograd raises the error
loss.backward()
# check we got the correct gradient
print(x.grad)
print(gradloss1 + gradloss2)
```
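
If it helps, here is what I believe is the smallest version of the failure I am hitting, stripped of the attention/loss code (the wording of the error below is just from memory; masked_fill apparently keeps the mask around for its backward pass):

```
import torch

x = torch.rand(4, requires_grad=True)
mask = torch.zeros(4).bool()
mask[0] = True
y = x.masked_fill(mask, float('-inf'))   # the mask seems to be saved for the backward pass
mask[3] = True                           # in-place change after the mask has been saved
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)  # complains that a tensor needed for gradient computation was modified in-place
```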

This can be fixed by replacing the code above with the following version, in which the boolean mask is rebuilt at each iteration instead of being updated in-place:

```
import torch
import torch.nn.functional as F
d = 4
x = torch.rand(d, requires_grad=True)
# iteration 1
label = 1
mask = torch.zeros(d).bool()  # fresh mask for iteration 1
mask[0] = True
y = x.masked_fill(mask, float('-inf'))
p = F.softmax(y, dim=0)
loss = -torch.log(p[label])
# compute by hand the grad of the loss at iteration 1
indicator = torch.zeros(d)
indicator[label] = 1
gradloss1 = p - indicator
# iteration 2
label = 1
mask = torch.zeros(d).bool()  # fresh mask for iteration 2, instead of reusing the previous one
mask[0] = True
mask[3] = True
y = x.masked_fill(mask, float('-inf'))
p = F.softmax(y, dim=0)
loss += -torch.log(p[label])
# compute by hand the grad of the loss at iteration 2
indicator = torch.zeros(d)
indicator[label] = 1
gradloss2 = p - indicator
# backprop
loss.backward()
# check we got the correct gradient
print(x.grad)
print(gradloss1 + gradloss2)
```
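
For what it's worth, I also noticed that keeping the in-place updates but passing a copy of the mask to masked_fill avoids the error as well. I am not sure whether this is the recommended way, it is just a sketch of what I tried:

```
import torch
import torch.nn.functional as F

d = 4
x = torch.rand(d, requires_grad=True)
mask = torch.zeros(d).bool()
loss = torch.zeros(())
for label, pos in [(1, 0), (1, 3)]:
    mask[pos] = True                                 # in-place update, as in the first snippet
    y = x.masked_fill(mask.clone(), float('-inf'))   # clone: the tensor saved for backward is never mutated later
    p = F.softmax(y, dim=0)
    loss = loss - torch.log(p[label])
loss.backward()
print(x.grad)
```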

I am a little confused because the boolean mask always has its requires_grad flag set to False. Why does it matter whether the mask is updated in-place or not? Can someone explain in a bit more detail how autograd works in this case? Thanks!