Seq2seq masked backprop

A: a 5x11x4 tensor of log-softmax outputs, where 5 is the number of sentences in the batch and 11 is the vocabulary size.
B: a 5x4 tensor holding the target sentences.

The lengths of the sentences in my batch are 3, 2, 1, 4 (the max is 4).

Now I want to prevent backprop from happening in the portions of the sentences beyond their lengths.

I can think of one way:

I can construct a 5x11x4 mask tensor whose third (time) dimension is 1 for the steps within a sentence's length and 0 otherwise.

Then I can multiply this element-wise with my original tensor and apply the NLLLoss function. That way backprop cannot affect the gradients at those elements.

Does this seem fine, or is another method recommended?

Also, even if I calculate the loss matrix first and then set the relevant entries to 0 before summing it up, it would be correct, right? Thanks.
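
A minimal sketch of that second variant (per-element losses zeroed before the reduction), assuming per-sentence lengths of 2, 3, 3, 4, 3 as in a later comment; the tensor names are illustrative:

```python
import torch
import torch.nn.functional as F

batch, vocab, max_len = 5, 11, 4
lengths = torch.tensor([2, 3, 3, 4, 3])  # assumed per-sentence lengths

# Stand-ins for the model's log-softmax output and the targets.
log_probs = torch.randn(batch, vocab, max_len, requires_grad=True).log_softmax(dim=1)
targets = torch.randint(0, vocab, (batch, max_len))

# Mask that is 1 for time steps inside each sentence, 0 for padding (5x4).
mask = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).float()

per_step = F.nll_loss(log_probs, targets, reduction='none')  # shape 5x4
loss = (per_step * mask).sum() / mask.sum()  # average over real tokens only
loss.backward()  # padded positions contribute zero gradient
```

If the padded target positions hold a dedicated index, NLLLoss's ignore_index argument achieves the same masking without building the mask by hand.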

The trick is to build a submatrix with only the entries that correspond to actual words. You can reshape the tensor to flatten the words and sentences together and keep only that subset. Then the NLL applies only to those entries, and backprop will do the correct thing, as in the sketch below.
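
A minimal sketch of this flatten-then-select idea (the lengths and tensor names are illustrative, not from the original posts):

```python
import torch
import torch.nn.functional as F

batch, vocab, max_len = 5, 11, 4
lengths = torch.tensor([2, 3, 3, 4, 3])  # assumed per-sentence lengths

log_probs = torch.randn(batch, vocab, max_len, requires_grad=True).log_softmax(dim=1)
targets = torch.randint(0, vocab, (batch, max_len))

# Flatten sentences and time steps together: (5, 11, 4) -> (20, 11), (5, 4) -> (20,).
flat_logp = log_probs.permute(0, 2, 1).reshape(-1, vocab)
flat_tgt = targets.reshape(-1)

# Boolean mask over the 20 positions: True where a real word sits.
valid = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).reshape(-1)

loss = F.nll_loss(flat_logp[valid], flat_tgt[valid])
loss.backward()  # only the kept rows receive gradient
```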

Even if I flatten the array, how can I eliminate the other entries? Earlier it was 5x11x4; now it will be 55x4. Out of these I only want some positions, since 4 is the maximum length: along that dimension of 4 I want 2, 3, 3, 4, 3 entries for the 5 sentences in my batch. Do you have an example, please? Thanks.

I understand your idea, but I am confused about the approach, as shown in my comment below.

Sorry, to be clearer: you are flattening the 1st and 3rd dimensions together, then selecting the valid indices to keep. Look at torch.take(), which can be applied after you flatten the array. You’ll have to create a second array with the valid indices to use it; a sketch follows below.
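
A sketch of building that index array from the lengths. One caveat: torch.take() indexes a flattened view of its input, so it applies directly to the 1-D target vector; to keep whole rows of the 2-D log-probability matrix, index_select() along dim 0 (used here as an assumed equivalent) does the row-wise selection:

```python
import torch
import torch.nn.functional as F

batch, vocab, max_len = 5, 11, 4
lengths = [2, 3, 3, 4, 3]  # per-sentence word counts from the comment above

log_probs = torch.randn(batch, vocab, max_len, requires_grad=True).log_softmax(dim=1)
targets = torch.randint(0, vocab, (batch, max_len))

flat_logp = log_probs.permute(0, 2, 1).reshape(-1, vocab)  # (20, 11)
flat_tgt = targets.reshape(-1)                             # (20,)

# Row s*max_len + t of the flattened matrix holds sentence s, time step t,
# so the valid row indices follow directly from the lengths.
valid_idx = torch.cat([s * max_len + torch.arange(n) for s, n in enumerate(lengths)])
# -> tensor([0, 1, 4, 5, 6, 8, 9, 10, 12, 13, 14, 15, 16, 17, 18])

kept_tgt = torch.take(flat_tgt, valid_idx)        # torch.take on the 1-D targets
kept_logp = flat_logp.index_select(0, valid_idx)  # row selection on the 2-D matrix

loss = F.nll_loss(kept_logp, kept_tgt)
```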

So you mean: if originally it is 5x11x4 with 5x1x4 targets, I can reshape to 20x11 and 20x1, and then, if the first sentence has 3 words and the second has 4, select indices 0, 1, 2, 4, 5, 6, 7 from both (these need to be calculated from the sentence lengths) and apply the loss function? Thanks a lot!
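
For reference, a small sketch that reproduces those indices from the two sentence lengths and applies the loss (names are illustrative):

```python
import torch
import torch.nn.functional as F

vocab, max_len = 11, 4
lengths = [3, 4]  # the two sentences from the comment above

log_probs = torch.randn(2, vocab, max_len, requires_grad=True).log_softmax(dim=1)
targets = torch.randint(0, vocab, (2, max_len))

flat_logp = log_probs.permute(0, 2, 1).reshape(-1, vocab)  # (8, 11)
flat_tgt = targets.reshape(-1)                             # (8,)

valid_idx = torch.cat([i * max_len + torch.arange(n) for i, n in enumerate(lengths)])
print(valid_idx)  # tensor([0, 1, 2, 4, 5, 6, 7])

loss = F.nll_loss(flat_logp[valid_idx], flat_tgt[valid_idx])
loss.backward()  # backprop touches only the selected positions
```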