I am using TransformerEncoder to train a model on sequences with a single masked value. Let's say every sequence has length S=10 and the batch size is N=64. The number of classes C is very large, but for the sake of this example, let's make it C=128. So the output from the encoder is SxNxC = 10x64x128. If I understand correctly, though, the idea is to only compute the loss at the output position corresponding to the masked token in the input sequence, e.g. some position S_i = 0...9, so that the actual shape of the output I want to feed to the loss function should be 64x128. I was going to naively use a for loop, but I am thinking there is probably a clever way to do this in a vectorized fashion. I tried playing around with creating my own mask, but so far haven't come upon a correct solution. Any advice appreciated!
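To make the shapes concrete, here is a minimal sketch of the for-loop version I had in mind. The tensors `output`, `masked_pos`, and `targets` are just placeholders I made up for this example (in my real code `output` comes from the encoder and `masked_pos` holds the masked index for each sequence):

```python
import torch
import torch.nn as nn

S, N, C = 10, 64, 128

# stand-in for the encoder output, shape (S, N, C)
output = torch.randn(S, N, C)

# assumed: the masked position (0..S-1) and target class for each sequence in the batch
masked_pos = torch.randint(0, S, (N,))
targets = torch.randint(0, C, (N,))

# naive loop: pick the output at the masked position for each sequence
rows = []
for n in range(N):
    rows.append(output[masked_pos[n], n])  # shape (C,)
selected = torch.stack(rows)               # shape (N, C) = 64x128

loss = nn.CrossEntropyLoss()(selected, targets)
```

This works, but I'd like to replace the loop over the batch with a single vectorized selection.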