Given a 3D tensor, say: batch x sentence length x embedding dim
a = torch.rand((10, 1000, 96))
and an array (or tensor) of actual lengths for each sentence
lengths = torch.randint(1000, (10,))
outputs tensor([ 370., 502., 652., 859., 545., 964., 566., 576.,1000., 803.])
How can I fill tensor ‘a’ with zeros after a certain index along dimension 1 (sentence length), according to tensor ‘lengths’?
I want something like this:
a[ : , lengths : , : ] = 0
One way of doing it (slow if the batch size is large):
for i, length in enumerate(lengths):
    # zero out every position from this sentence's length onward
    a[i, length:, :] = 0
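A possible loop-free alternative, sketched here under the assumption that ‘lengths’ holds one non-negative length per batch row: build a boolean mask by comparing a range of positions against the lengths via broadcasting, then zero out the masked positions. The function name zero_after_lengths is hypothetical and only used for illustration.

```python
import torch

def zero_after_lengths(a, lengths):
    """Return a copy of `a` (batch x seq_len x dim) with a[i, lengths[i]:, :] set to 0.

    `zero_after_lengths` is a hypothetical helper name, not a PyTorch function.
    """
    # positions: (seq_len,) holding 0, 1, ..., seq_len - 1
    positions = torch.arange(a.size(1), device=a.device)
    # keep: (batch, seq_len), True where the position is inside the sentence
    keep = positions[None, :] < lengths.to(a.device)[:, None]
    # broadcast the mask over the embedding dimension and zero the rest
    return a.masked_fill(~keep[:, :, None], 0.0)

a = torch.rand((10, 1000, 96))
lengths = torch.randint(1000, (10,))
a = zero_after_lengths(a, lengths)
```

If an in-place update is preferred, the same mask should also work with a.masked_fill_(...) instead of masked_fill. Whether this is actually faster than the loop depends on batch size and device, so it is worth timing on your own data.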