Efficiently Combine Two Tensors Based on a Boolean Mask

Hi! I have two tensors:

  • tensor_a ([batch_size, seq_len_a, embedding_dim])
  • tensor_b ([batch_size, seq_len_b, embedding_dim])

The total sequence length is seq_len_total = seq_len_a + seq_len_b. I also have a boolean tensor mask ([batch_size, seq_len_total]) where True corresponds to positions for tensor_a and False corresponds to positions for tensor_b.

How can I efficiently combine tensor_a and tensor_b into a single tensor of shape ([batch_size, seq_len_total, embedding_dim]) using the mask?
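
For example, with tiny made-up shapes, the result I'm after would look like this:

import torch

# batch_size = 1, seq_len_a = 2, seq_len_b = 3, embedding_dim = 1
tensor_a = torch.tensor([[[1.], [2.]]])
tensor_b = torch.tensor([[[10.], [20.], [30.]]])
mask = torch.tensor([[True, False, True, False, False]])

# desired output: tensor_a's vectors fill the True positions in order,
# tensor_b's vectors fill the False positions in order:
# [[[ 1.], [10.], [ 2.], [20.], [30.]]]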

I’m unsure if I understand your use case correctly, but direct indexing with the mask should work:

import torch

batch_size, seq_len_a, seq_len_b, embedding_dim = 2, 3, 4, 5

a = torch.ones(batch_size, seq_len_a, embedding_dim)
b = torch.ones(batch_size, seq_len_b, embedding_dim) * 2

# create a random demo mask with exactly seq_len_a True and seq_len_b False
# entries per row, so each batch element receives all of its own a and b vectors
mask = torch.cat((torch.ones(batch_size, seq_len_a, dtype=torch.bool),
                  torch.zeros(batch_size, seq_len_b, dtype=torch.bool)), dim=1)
perm = torch.argsort(torch.rand(batch_size, seq_len_a + seq_len_b), dim=1)
mask = torch.gather(mask, 1, perm)

# boolean indexing enumerates True positions in row-major order, which matches
# the order of the flattened a and b tensors
out = torch.zeros(batch_size, seq_len_a + seq_len_b, embedding_dim)
out[mask] = a.view(-1, embedding_dim)
out[~mask] = b.view(-1, embedding_dim)
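
As a quick sanity check (reusing the variables from the snippet above), every True position should now hold a value from a and every False position a value from b:

# a is all 1s and b is all 2s in this demo, so the check is straightforward
assert (out[mask] == 1).all()
assert (out[~mask] == 2).all()
print(out[0, :, 0])  # e.g. tensor([1., 2., 1., 2., 1., 2., 2.]), depending on the random mask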