Split a tensor into two groups using a mask matrix

Given a tensor A with shape (b, n, d) and a matrix M with shape (b, n), what is the most efficient way to split A into two tensors B and C with shapes (b, l, d) and (b, m, d), respectively, where n = l + m? The row A[k,i,:] should be included in B if M[k,i] == 1, and in C otherwise.

The trivial solution is a for loop that selects each row of A based on the corresponding value in M. Is there any way to do this splitting in parallel?


I guess you could use scatter_ for this operation. Could you post your loop approach, so that we can check how exactly you are currently indexing the tensors?

Thanks for your response. Here’s the loop approach:

B = []
C = []
for i in range(M.shape[0]):
    b = []
    c = []
    for j in range(M.shape[1]):
        if M[i, j] == 1:
            c += [A[i, j]]
        else:
            b += [A[i, j]]
    B += [torch.stack(b, dim=0)]
    C += [torch.stack(c, dim=0)]

There might be some padding required to convert B and C to tensors, but this is the general idea of the loop approach.
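Since the number of selected rows can differ between batch entries, one way to handle that padding (a sketch of my own, not from the thread, using a small hypothetical A and M) is `torch.nn.utils.rnn.pad_sequence`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# hypothetical setup for illustration
A = torch.randn(2, 5, 3)
M = torch.tensor([[1, 0, 1, 0, 0],
                  [1, 1, 1, 0, 1]])

# collect rows per batch entry; the counts can differ between entries
b_list = [A[i][M[i] == 0] for i in range(A.shape[0])]
c_list = [A[i][M[i] == 1] for i in range(A.shape[0])]

# zero-pad the shorter lists so B and C become regular (b, l_max, d) tensors
B = pad_sequence(b_list, batch_first=True)  # -> (2, 3, 3)
C = pad_sequence(c_list, batch_first=True)  # -> (2, 4, 3)
```

The padded positions are filled with zeros, so you may also want to keep the per-entry lengths around if you need to mask them out later.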

Assuming M only contains zeros and ones (or you could clamp it to these values), you could use it as a mask as seen here:

# setup
b, n, d = 2, 10, 4
A = torch.randn(b, n, d)
M = torch.randint(0, 2, (b, n))

# your approach
B = []
C = []
for i in range(M.shape[0]):
    b = []
    c = []
    for j in range(M.shape[1]):
        if M[i,j] == 1:
            c += [A[i,j]]
        else:
            b += [A[i,j]]
    B += [torch.stack(b, dim=0)]
    C += [torch.stack(c, dim=0)]

B = torch.cat(B)
C = torch.cat(C)

# alternative
B_ = A[~M.bool()]
C_ = A[M.bool()]

print((B == B_).all())
> tensor(True)
print((C == C_).all())
> tensor(True)
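Note that boolean indexing flattens the batch dimension: B_ and C_ have shapes (num_selected, d) rather than (b, l, d), and the comparisons above only pass because the loop results were also flattened via torch.cat. A quick shape check (my own example, not from the post):

```python
import torch

b, n, d = 2, 10, 4
A = torch.randn(b, n, d)
M = torch.randint(0, 2, (b, n))

C_ = A[M.bool()]   # rows with M == 1, gathered across all batch entries
B_ = A[~M.bool()]  # rows with M == 0

# the first dimension is the total count over the whole batch, not per entry
assert C_.shape == (M.sum().item(), d)
assert B_.shape == ((M == 0).sum().item(), d)
```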

In this case, how can A be reconstructed again from B, C, and M?
Something like below:

A_again = mask_concatenate(B, C, M)
assert (A == A_again).all() == True
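One way to invert the split (a sketch, assuming B and C are the flattened tensors produced by the masked-indexing approach above, with B holding the M == 0 rows): boolean mask assignment places rows back in the same row-major order in which they were gathered, so a hypothetical mask_concatenate could look like:

```python
import torch

def mask_concatenate(B, C, M):
    # hypothetical helper: rebuilds A from the flattened splits B (M == 0 rows)
    # and C (M == 1 rows); boolean assignment preserves the gathering order
    mask = M.bool()
    A = B.new_empty(M.shape[0], M.shape[1], B.shape[-1])
    A[mask] = C
    A[~mask] = B
    return A

# round trip
A = torch.randn(2, 10, 4)
M = torch.randint(0, 2, (2, 10))
B, C = A[~M.bool()], A[M.bool()]
A_again = mask_concatenate(B, C, M)
assert (A == A_again).all()
```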