Thanks for your response. Here’s the loop approach:
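import torch

# A: data tensor, M: mask with the same first two dimensions as A
B = []  # per row: elements of A where M is not 1
C = []  # per row: elements of A where M == 1
for i in range(M.shape[0]):
    b = []
    c = []
    for j in range(M.shape[1]):
        if M[i, j] == 1:
            c.append(A[i, j])
        else:
            b.append(A[i, j])
    # torch.stack raises on an empty list, so each row of M needs at least one 1 and one non-1 entry
    B.append(torch.stack(b, dim=0))
    C.append(torch.stack(c, dim=0))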
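Since each row can have a different number of masked elements, B and C end up as lists of tensors with different lengths, so some padding would be required to convert them to tensors. This is the general idea of the loop approach, though.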
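For the padding step, here is a minimal sketch of one way it could look, assuming A and M share their first two dimensions and that padding with a constant is acceptable. The helper name `split_by_mask` and the `pad_value` parameter are made up for illustration. It replaces the inner loop with boolean indexing per row and uses `torch.nn.utils.rnn.pad_sequence` to pad to the longest row:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def split_by_mask(A, M, pad_value=0.0):
    """Hypothetical helper: split each row of A by the mask M, then pad to equal length."""
    B_rows, C_rows = [], []
    for i in range(M.shape[0]):
        keep = M[i] == 1              # boolean mask for row i
        C_rows.append(A[i][keep])     # elements where M == 1
        B_rows.append(A[i][~keep])    # all remaining elements
    # pad_sequence pads every row up to the length of the longest one
    B = pad_sequence(B_rows, batch_first=True, padding_value=pad_value)
    C = pad_sequence(C_rows, batch_first=True, padding_value=pad_value)
    return B, C


# Small example: 2 rows, 3 columns
A = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
M = torch.tensor([[1, 0, 0], [1, 1, 0]])
B, C = split_by_mask(A, M)
print(B)  # rows [2., 3.] and [6., 0.]
print(C)  # rows [1., 0.] and [4., 5.]
```

Note that the padded positions are indistinguishable from real zeros in A, so you may want to keep the row lengths around or pick a padding value that cannot occur in your data.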