Hi there,
I am having trouble converting the following function into one that uses only tensor operations (no Python loops).
```python
import torch

def valid_sequence_output(sequence_output, valid_mask, attention_mask):
    batch_size, max_len, feat_dim = sequence_output.shape
    # Allocate the outputs on the same device as the inputs
    valid_output = torch.zeros(batch_size, max_len, feat_dim, dtype=torch.float32,
                               device=sequence_output.device)
    valid_attention_mask = torch.zeros(batch_size, max_len, dtype=torch.long,
                                       device=sequence_output.device)
    for i in range(batch_size):
        jj = -1  # index of the last filled output slot in row i
        for j in range(max_len):
            if valid_mask[i][j].item() == 1:
                # Copy each valid position into the next free slot, keeping the original order
                jj += 1
                valid_output[i][jj] = sequence_output[i][j]
                valid_attention_mask[i][jj] = attention_mask[i][j]
    return valid_output, valid_attention_mask
```
where the input tensors can be created as follows:
```python
import torch

size = (2, 5, 2)
sequence_output = torch.randint(0, 250, size=size)
valid_mask = torch.randint(0, 2, size=size[:2])
attention_mask = torch.randint(0, 2, size=size[:2])
```
Basically, I want to "bubble up" the rows of sequence_output and attention_mask where valid_mask is 1 to the top of each sequence, keeping their original order and zero-filling the remaining slots. The closest I could get was something like:
```python
torch.where(valid_mask.unsqueeze(-1) == 1, sequence_output, torch.zeros_like(sequence_output))
```
but this only zeroes out the invalid rows in place; I couldn't find a way to then move the remaining rows to the top.
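One possible loop-free sketch, assuming all three inputs are on the same device and valid_mask contains only 0s and 1s: build a per-row sort key that keeps each valid index j unchanged and pushes invalid indices past max_len, argsort that key to get a permutation with the valid positions first (in their original order), gather with that permutation, and finally zero out the slots beyond the number of valid positions. The name valid_sequence_output_vectorized is hypothetical, and the feature output is cast to float32 only to mirror the loop version above.

```python
import torch

def valid_sequence_output_vectorized(sequence_output, valid_mask, attention_mask):
    # Hypothetical loop-free counterpart of valid_sequence_output (sketch, not a reference implementation)
    batch_size, max_len, feat_dim = sequence_output.shape
    valid = valid_mask == 1  # (B, L) bool
    positions = torch.arange(max_len, device=valid_mask.device).expand(batch_size, max_len)
    # Valid positions keep key j; invalid ones get j + max_len, so they sort after every valid one.
    # All keys in a row are distinct, so a plain argsort keeps the valid rows in their original order.
    key = positions + (~valid).long() * max_len
    order = key.argsort(dim=1)  # (B, L) permutation: valid indices first
    # Reorder the feature rows and the attention mask with the same permutation
    gathered = sequence_output.gather(1, order.unsqueeze(-1).expand(-1, -1, feat_dim))
    gathered_attn = attention_mask.gather(1, order)
    # After reordering, slot k holds a valid row iff k < number of valid positions in that row
    keep = valid.gather(1, order)
    valid_output = torch.where(keep.unsqueeze(-1), gathered, torch.zeros_like(gathered))
    valid_attention_mask = torch.where(keep, gathered_attn, torch.zeros_like(gathered_attn))
    # Match the dtypes allocated by the loop version (float32 features, long mask)
    return valid_output.float(), valid_attention_mask.long()
```

Making every key unique avoids depending on whether the sort is stable, which is what preserves the original order of the valid rows.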
If anybody has a suggestion on how to do this, I would really appreciate it.
Cheers,
Jules