Masking in Transformer decoders with -inf rather than 0s

Hi,

I have a question about the implementation of masking in Transformer decoders. I understand that the purpose of masking is to prevent the decoder from peeking at future tokens in the target sequence.
I am trying to understand why the masked future positions are filled with -inf rather than 0s. https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L3427
I imagined that these positions would be filled with 0s, so that when the dot product is taken, the corresponding future positions would be 0, and so would the gradient; then there would be no change in the parameters at those future positions when we take a step along the gradient. With -inf, I would think the gradients would be -inf/ill-defined, and the resulting solution would have NaNs in the subsequent positions.

Could you please help me understand the reasoning for filling with -inf?

Thanks!

I guess it is related to the softmax function. In the Transformer/BERT architecture the masked scores are passed through a softmax, and since exp(-inf) = 0, the -inf entries vanish in that case.
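
To make that concrete, here is a small sketch (a toy illustration, not the internal nn.Transformer code) contrasting masking the scores with -inf versus 0 before the softmax:

import torch

# toy attention scores for one query over 4 key positions;
# positions 2 and 3 are "future" tokens we want to hide
scores = torch.tensor([0.3, 1.1, 0.7, -0.2])
future = torch.tensor([False, False, True, True])

probs_inf = torch.softmax(scores.masked_fill(future, float('-inf')), dim=-1)
probs_zero = torch.softmax(scores.masked_fill(future, 0.0), dim=-1)

print(probs_inf)   # masked positions get exactly 0 probability
print(probs_zero)  # masked positions still get non-zero probability

Only the -inf fill makes the attention weights at the future positions exactly zero.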

We could carry out the following experiment to see what is going on while computing the gradient. Let us assume we have

l = nn.Linear(2, 2, bias=False)
list(l.parameters())

[Parameter containing:
tensor([[-0.0473,  0.0679],
        [-0.6429, -0.5586]], requires_grad=True)]

input = torch.tensor([[1.3032, 0.5701], [0.2671, 1.3900]])
o1 = l(input)
o1

tensor([[-0.0230, -1.1563],
        [ 0.0817, -0.9482]], grad_fn=<MmBackward>)

loss_fn = nn.CrossEntropyLoss(reduction='sum')
loss = loss_fn(o1, torch.tensor([0, 1]))
# assume that target is [0, 1]

The softmax (nn.Softmax(dim=-1)) over a row (z_0, z_1) is

softmax(z)_i = exp(z_i) / (exp(z_0) + exp(z_1))

The output after passing the input through the linear layer (w00, w01, w10, w11 are the weights; nn.Linear computes o = x W^T) would be

o = [[1.3032*w00 + 0.5701*w01, 1.3032*w10 + 0.5701*w11],
     [0.2671*w00 + 1.3900*w01, 0.2671*w10 + 1.3900*w11]]

After passing this through cross entropy loss with reduction='sum' (we do not apply softmax externally; cross entropy has an inbuilt softmax),

L = -log( exp(o_00) / (exp(o_00) + exp(o_01)) ) - log( exp(o_11) / (exp(o_10) + exp(o_11)) )

(since our target was torch.tensor([0, 1]), we considered the 0th index in the first row and the 1st index in the second row).

After taking the Jacobian (using dL/do_ij = softmax(o)_ij - y_ij and dL/dW = (dL/do)^T x) and substituting the weights with their values,

dL/dW = [[-0.1206,  0.8855],
         [ 0.1206, -0.8855]]


if we do

loss.backward()
for param in l.parameters():
  print(param.grad)

then we get,

tensor([[-0.1206,  0.8855],
        [ 0.1206, -0.8855]])

which matches with our calculated gradient.
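
As a numeric cross-check (a small sketch reusing the tensors defined above, not part of the original derivation), the same gradient falls out of dL/do = softmax(o) - y followed by dL/dW = (dL/do)^T x:

import torch
import torch.nn.functional as F

# `input`, `o1`, and the linear layer `l` are the ones defined above
y = torch.tensor([0, 1])
p = torch.softmax(o1, dim=-1)                     # the softmax built into cross entropy
grad_o = p - F.one_hot(y, num_classes=2).float()  # dL/do for CE with reduction='sum'
grad_W = grad_o.t() @ input                       # dL/dW, since o1 = input @ W.T
print(grad_W)                                     # same values as l.weight.grad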

If our output after applying the linear layer contained float('-inf'), that is, if it were something like

o1 = torch.tensor([[float('-inf'), 2.], [1., 2.]], requires_grad=True)

then after applying cross entropy (again with target torch.tensor([0, 1])),

L = -log( exp(-inf) / (exp(-inf) + exp(2)) ) - log( exp(2) / (exp(1) + exp(2)) )

After computing the Jacobian, dL/do1_ij = softmax(o1)_ij - y_ij, and substituting the values (note exp(-inf) = 0, so the first row's softmax is (0, 1)),

dL/do1 = [[0 - 1, 1 - 0],
          [exp(1)/(exp(1) + exp(2)) - 0, exp(2)/(exp(1) + exp(2)) - 1]]
       = [[-1.0000,  1.0000],
          [ 0.2689, -0.2689]]

which matches

o1 = torch.tensor([[float('-inf'), 2.], [1., 2.]], requires_grad=True)
loss_fn = nn.CrossEntropyLoss(reduction='sum')
loss = loss_fn(o1, torch.tensor([0, 1]))
loss.backward()
o1.grad

tensor([[-1.0000,  1.0000],
        [ 0.2689, -0.2689]])

All the float('-inf') terms vanish.
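
The same numbers fall out of the softmax-cross-entropy identity directly (a quick numeric sketch, not from the original post):

import torch
import torch.nn.functional as F

o1 = torch.tensor([[float('-inf'), 2.], [1., 2.]])
y = torch.tensor([0, 1])
p = torch.softmax(o1, dim=-1)   # exp(-inf) = 0, so p[0] = [0., 1.]
print(p - F.one_hot(y, num_classes=2).float())
# tensor([[-1.0000,  1.0000],
#         [ 0.2689, -0.2689]])  -- finite everywhere, no NaNs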


Ahh yes, I wasn't thinking of the subsequent softmax operation. We want the elements of the softmax output corresponding to future tokens to be 0, not the dot product itself. Thanks for the prompt response!
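
For completeness, here is a small sketch of how this looks with an actual causal mask (the -inf upper-triangular layout is the same shape of mask that nn.Transformer.generate_square_subsequent_mask produces):

import torch

T = 4
# causal mask: 0 where attention is allowed, -inf at future positions
mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)

scores = torch.randn(T, T)                   # raw q @ k.T attention scores
attn = torch.softmax(scores + mask, dim=-1)
print(attn)  # upper triangle (future positions) is exactly 0 in every row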

Thanks for the detailed response :slight_smile: May I ask what software/package you used for the symbolic representation in your screenshots? Can you do that with pytorch?

I use sympy for symbolic mathematics.
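
For example, a minimal sympy sketch of the derivative computed above (PyTorch's autograd works numerically rather than symbolically, so sympy is handy for the closed-form expressions):

import sympy as sp

o00, o01 = sp.symbols('o00 o01')

# cross entropy of the row [o00, o01] with target class 0 (softmax is built in)
L = -sp.log(sp.exp(o00) / (sp.exp(o00) + sp.exp(o01)))

# symbolic gradient, i.e. softmax(o)_j - y_j
print(sp.simplify(sp.diff(L, o00)))   # equals softmax(o)_0 - 1
print(sp.simplify(sp.diff(L, o01)))   # equals softmax(o)_1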