# Split a tensor into two groups using a mask matrix

Given a tensor `A` with shape `(b,n,d)` and a matrix `M` with shape `(b,n)`, what is the most efficient way to split `A` into two tensors `B` and `C` with shapes `(b,l,d)` and `(b,m,d)`, respectively, where `n=l+m`? The row `A[k,i,:]` should go into `B` if `M[k,i] == 1` and into `C` otherwise.

The trivial solution is a for loop that selects the rows of `A` based on the corresponding values in `M`. Is there a way to do this split in parallel?


I guess you could use `scatter_` for this operation. Could you post your loop approach, so that we could check how exactly you are currently indexing the tensors?
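For reference, here is a minimal batched sketch in the same spirit. It uses `torch.gather` with a stable sort instead of `scatter_` (a substitution, not something proposed above). It assumes `M` contains only 0 and 1 and that every batch row has the same number `l` of ones, which is what the `(b,l,d)` / `(b,m,d)` shapes in the question imply. The helper name `split_by_mask_sorted` is hypothetical.

```python
import torch

def split_by_mask_sorted(A, M):
    """Hypothetical helper: reorder the rows of A so that, in each batch row,
    entries with M == 1 come first and entries with M == 0 follow, keeping
    the original order inside each group."""
    # a stable descending sort puts the ones before the zeros without
    # shuffling the relative order of equal values
    _, idx = torch.sort(M, dim=1, descending=True, stable=True)
    # broadcast the row indices over the feature dimension and gather the rows
    idx = idx.unsqueeze(-1).expand(-1, -1, A.size(-1))
    return torch.gather(A, 1, idx)

b, n, d = 2, 6, 3
A = torch.randn(b, n, d)
# assumption: every batch row contains exactly l = 2 ones
M = torch.tensor([[0, 1, 0, 0, 1, 0],
                  [1, 0, 0, 1, 0, 0]])
l = 2

A_sorted = split_by_mask_sorted(A, M)
B = A_sorted[:, :l]   # (b, l, d): rows where M == 1
C = A_sorted[:, l:]   # (b, m, d): rows where M == 0
```

If the number of ones differs between batch rows, the sorted tensor still groups the ones first, but a single slice width no longer separates `B` from `C` for every row.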

Thanks for your response. Here’s the loop approach:

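```python
import torch

B = []
C = []
for i in range(M.shape[0]):           # loop over the batch dimension
    b_rows = []
    c_rows = []
    for j in range(M.shape[1]):       # loop over the rows of A[i]
        if M[i, j] == 1:
            b_rows += [A[i, j]]       # rows with M == 1 go to B
        else:
            c_rows += [A[i, j]]       # all other rows go to C
    B += [torch.stack(b_rows, dim=0)] # (l_i, d)
    C += [torch.stack(c_rows, dim=0)] # (m_i, d)
```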

Since each batch entry can have a different number of ones in `M`, the per-entry results in `B` and `C` may need padding before they can be stacked into tensors, but this is the general idea of the loop approach.
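One way to do that padding is sketched below with `torch.nn.utils.rnn.pad_sequence`, which zero-pads a list of `(length, d)` tensors to a common length. It assumes `B` collects the rows where `M == 1` (as in the question statement) and only loops over the batch dimension; the small hand-written `M` is just for illustration. The true lengths are kept alongside so the padding can be ignored later.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

b, n, d = 2, 5, 3
A = torch.randn(b, n, d)
M = torch.tensor([[1, 0, 1, 1, 0],
                  [0, 0, 1, 0, 0]])
mask = M.bool()

# per-batch selections; their lengths can differ between batch rows
B_list = [A[i][mask[i]] for i in range(b)]    # each (l_i, d)
C_list = [A[i][~mask[i]] for i in range(b)]   # each (m_i, d)

# pad to the longest entry so the results stack into regular tensors
B = pad_sequence(B_list, batch_first=True)    # (b, max l_i, d), zero-padded
C = pad_sequence(C_list, batch_first=True)    # (b, max m_i, d), zero-padded

# true lengths, e.g. for building an attention or loss mask later
B_len = mask.sum(dim=1)
C_len = (~mask).sum(dim=1)
```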

Assuming `M` only contains zeros and ones (or you could `clamp` it to these values), you could use it as a mask as seen here:

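```python
import torch

# setup
b, n, d = 2, 10, 4
A = torch.randn(b, n, d)
M = torch.randint(0, 2, (b, n))

# loop approach (assumes every batch row has at least one 0 and one 1,
# since torch.stack fails on an empty list)
B = []
C = []
for i in range(M.shape[0]):
    b_rows = []
    c_rows = []
    for j in range(M.shape[1]):
        if M[i, j] == 1:
            b_rows += [A[i, j]]
        else:
            c_rows += [A[i, j]]
    B += [torch.stack(b_rows, dim=0)]
    C += [torch.stack(c_rows, dim=0)]

B = torch.cat(B)  # (total number of ones, d)
C = torch.cat(C)  # (total number of zeros, d)

# alternative: index A with the boolean mask directly
B_ = A[M.bool()]
C_ = A[~M.bool()]

print((B == B_).all())
# tensor(True)
print((C == C_).all())
# tensor(True)
```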
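Note that `A[mask]` returns a 2D tensor of shape `(number of selected rows, d)` with the batch rows concatenated in order, so the batch dimension is flattened away. A minimal sketch for recovering one tensor per batch entry, assuming `B` holds the rows where `M == 1`, uses the per-row counts with `torch.split`:

```python
import torch

torch.manual_seed(0)
b, n, d = 3, 6, 4
A = torch.randn(b, n, d)
M = torch.randint(0, 2, (b, n))
mask = M.bool()

B_flat = A[mask]     # (total ones, d), batch rows concatenated in order
C_flat = A[~mask]    # (total zeros, d)

# number of selected rows in each batch entry
l_per_row = mask.sum(dim=1).tolist()
m_per_row = (~mask).sum(dim=1).tolist()

# cut the flat results back into one tensor per batch entry
B_parts = torch.split(B_flat, l_per_row)   # tuple of (l_i, d) tensors
C_parts = torch.split(C_flat, m_per_row)   # tuple of (m_i, d) tensors
```

When every batch row has the same count `l`, the split is unnecessary and `B_flat.view(b, l, d)` directly gives the `(b, l, d)` shape from the question.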

In this case, how can I reassemble `A` from `B`, `C`, and `M`? Something like:

```python
A_again = mask_concatenate(B, C, M)
assert (A == A_again).all()
```
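A minimal sketch of such a function, assuming `B` and `C` are the flat results of `A[M.bool()]` and `A[~M.bool()]` (rows with `M == 1` in `B`; swap the masks for the opposite convention). `mask_concatenate` is the hypothetical name from the question, not an existing PyTorch function. Boolean-mask assignment writes rows back in the same row-major order in which boolean indexing read them, so the original layout is restored.

```python
import torch

def mask_concatenate(B, C, M):
    """Hypothetical helper: rebuild A from B (rows where M == 1) and
    C (rows where M == 0), both in the flat (rows, d) form."""
    mask = M.bool()
    # allocate (b, n, d) with B's dtype and device
    A = B.new_empty((*M.shape, B.size(-1)))
    # write the rows back in the order boolean indexing produced them
    A[mask] = B
    A[~mask] = C
    return A

b, n, d = 2, 10, 4
A = torch.randn(b, n, d)
M = torch.randint(0, 2, (b, n))
B = A[M.bool()]
C = A[~M.bool()]

A_again = mask_concatenate(B, C, M)
assert (A == A_again).all()
```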