Filtering out null rows of a 3-D Tensor

Hi there,

I am having trouble converting the following function into one that only uses tensor operations, without Python loops.

import torch

def valid_sequence_output(sequence_output, valid_mask, attention_mask):
    batch_size, max_len, feat_dim = sequence_output.shape
    # Allocate the outputs on the same device as the input
    valid_output = torch.zeros(batch_size, max_len, feat_dim, dtype=torch.float32,
                               device=sequence_output.device)
    valid_attention_mask = torch.zeros(batch_size, max_len, dtype=torch.long,
                                       device=sequence_output.device)
    for i in range(batch_size):
        jj = -1  # index of the last slot filled in row i
        for j in range(max_len):
            if valid_mask[i][j].item() == 1:
                # Copy each valid position into the next free slot at the front
                jj += 1
                valid_output[i][jj] = sequence_output[i][j]
                valid_attention_mask[i][jj] = attention_mask[i][j]
    return valid_output, valid_attention_mask

where the input tensors can be created as follows:

size = (2, 5, 2)
sequence_output = torch.randint(0, 250, size=size)
valid_mask = torch.randint(0, 2, size=size[:2])
attention_mask = torch.randint(0, 2, size=size[:2])
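
To make the goal concrete, here is a small hand-built case. The values and the `expected_*` names are made up for illustration; they are not taken from the random inputs above.

```python
import torch

# One sequence of 5 positions with 2 features each (hand-picked values).
sequence_output = torch.tensor([[[1., 1.], [2., 2.], [3., 3.], [4., 4.], [5., 5.]]])
valid_mask = torch.tensor([[0, 1, 0, 1, 1]])
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])

# Desired result: positions where valid_mask == 1 move to the front,
# keeping their original order, and the remaining slots are zero.
expected_output = torch.tensor([[[2., 2.], [4., 4.], [5., 5.], [0., 0.], [0., 0.]]])
expected_attention_mask = torch.tensor([[1, 1, 0, 0, 0]])

# Boolean indexing selects the right rows for a single sequence, but it
# returns a variable number of rows, so it does not batch directly.
kept_rows = sequence_output[0][valid_mask[0].bool()]
kept_attention = attention_mask[0][valid_mask[0].bool()]
```

The missing piece is doing this for every sequence in the batch at once while keeping the padded `(batch_size, max_len, feat_dim)` shape.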

Basically, I want to “bubble up” the rows of sequence_output and attention_mask that valid_mask marks with 1 to the front of each sequence, leaving zeros at the end. The closest I could get was something like:

torch.where(valid_mask.unsqueeze(-1) == 1, sequence_output, torch.zeros_like(sequence_output))

but then I couldn’t find a way to “bubble up” the non-null rows.
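
One possible direction, offered as a sketch rather than a definitive answer: a stable sort on the inverted mask yields, for each sequence, an index order with the valid positions first and their relative order preserved, and `torch.gather` can then reorder the rows. The function name `valid_sequence_output_vectorized` is hypothetical, and the sketch assumes a PyTorch version where `torch.sort` accepts `stable=True`.

```python
import torch

def valid_sequence_output_vectorized(sequence_output, valid_mask, attention_mask):
    # True where a position should be kept: shape (batch_size, max_len)
    keep = valid_mask == 1
    # Sort the inverted mask ascending: kept positions (0) come first.
    # stable=True preserves the original order among equal values.
    sorted_flags, order = torch.sort((~keep).long(), dim=1, stable=True)
    packed_keep = sorted_flags == 0  # True for the leading, filled slots

    # Reorder the feature rows with the same permutation.
    index = order.unsqueeze(-1).expand(-1, -1, sequence_output.size(-1))
    valid_output = torch.gather(sequence_output, 1, index)
    # Zero out the trailing slots that hold non-valid positions.
    valid_output = valid_output.masked_fill(~packed_keep.unsqueeze(-1), 0)

    valid_attention_mask = torch.gather(attention_mask, 1, order)
    valid_attention_mask = valid_attention_mask.masked_fill(~packed_keep, 0)
    return valid_output, valid_attention_mask
```

Unlike the loop version, this sketch keeps the dtype of `sequence_output` instead of casting to `float32`; a `.float()` call on the result would restore that behaviour if it matters.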

If anybody has a suggestion on how to do this, I would really appreciate it :smiley:

Cheers,
Jules