Add rows of zeros with a mask

Suppose there are two torch tensors:

import torch

A=torch.arange(0,10)
B=torch.tensor([0,2,4,7,9])

First, I would like to produce a mask marking the values of A that are missing from B.
Then I would like to apply that mask to another tensor, C, which has one row per element of B:

C=torch.rand(5,2)

The result should be C with a row of zeros inserted at every position where a value of A is missing from B, so the final output has size (10, 2).

I could do this with a loop, but these tensors have millions of rows. Is there a vectorized (parallel) way to approach this problem?
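For reference, a loop version of the requested operation might look like the sketch below. It is only a slow baseline for checking vectorized versions against, and it assumes that B's values are unique, all appear in A, and appear in the same order as in A (true for the example above); `fill_missing_rows_loop` is a hypothetical helper name.

```python
import torch

def fill_missing_rows_loop(A, B, C):
    # Hypothetical reference helper (slow): walk through A and copy the next
    # row of C whenever the current value of A appears in B; otherwise the
    # row stays zero. Assumes B's values are unique, all occur in A, and
    # appear in the same order as in A.
    out = torch.zeros(A.shape[0], C.shape[1], dtype=C.dtype)
    b_set = set(B.tolist())
    j = 0  # next row of C to copy
    for i, value in enumerate(A.tolist()):
        if value in b_set:
            out[i] = C[j]
            j += 1
    return out

A = torch.arange(0, 10)
B = torch.tensor([0, 2, 4, 7, 9])
C = torch.rand(5, 2)
out = fill_missing_rows_loop(A, B, C)
print(out.shape)  # torch.Size([10, 2])
```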

Going by your description, you can do something like this:

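import torch

A = torch.arange(0, 10)
B = torch.tensor([0, 2, 4, 7, 9])

# mask[j] is True when A[j] does not appear anywhere in B
mask = torch.all(A - B.unsqueeze(0).T, dim=0)

C = torch.rand(5, 2)
# One row per row of C, plus one zero row per value of A missing from B
D = torch.zeros(C.shape[0] + mask.count_nonzero(), 2)

# Copy C into the rows whose values are present in B; the rest stay zero
D[mask.logical_not(), :] = C
print(C)
print(D)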
# Output:
# C
tensor([[0.8500, 0.1220],
        [0.8756, 0.0267],
        [0.6929, 0.1234],
        [0.4816, 0.5664],
        [0.0311, 0.8782]])
# D
tensor([[0.8500, 0.1220],
        [0.0000, 0.0000],
        [0.8756, 0.0267],
        [0.0000, 0.0000],
        [0.6929, 0.1234],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.4816, 0.5664],
        [0.0000, 0.0000],
        [0.0311, 0.8782]])
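To see what the mask line computes, here is a small sketch on the same A and B that prints the intermediate shape and the resulting mask. It only unpacks the line from the answer above; nothing new is assumed.

```python
import torch

A = torch.arange(0, 10)
B = torch.tensor([0, 2, 4, 7, 9])

# B.unsqueeze(0).T has shape (5, 1); broadcasting it against A (shape (10,))
# gives a (5, 10) matrix with diff[i, j] = A[j] - B[i].
diff = A - B.unsqueeze(0).T
print(diff.shape)  # torch.Size([5, 10])

# A column is all nonzero exactly when A[j] matches no element of B,
# so mask is True at the positions of A that are missing from B.
mask = torch.all(diff, dim=0)
print(mask.tolist())  # [False, True, False, True, False, True, True, False, True, False]
print(mask.count_nonzero().item())  # 5
```

Note that `diff` has `len(B) * len(A)` elements, which matters once the tensors get large.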

If you already know the total size, and B holds the row indices where the rows of C belong, you can skip the mask and index directly to get the same result:

E = torch.zeros(10, 2)

# Row k of C goes to row B[k] of E; all other rows stay zero
E[B] = C
print(E)
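The same row scatter can also be written with `Tensor.index_copy_`, which copies row k of the source into row `index[k]` of the target along the given dimension. A sketch checking that it matches plain indexing, assuming (as in this example) that every value of B is a valid, unique row index:

```python
import torch

B = torch.tensor([0, 2, 4, 7, 9])
C = torch.rand(5, 2)

# Direct indexing: row k of C goes to row B[k] of E.
E = torch.zeros(10, 2)
E[B] = C

# Same operation with index_copy_ along dim 0 (index must be int64).
F = torch.zeros(10, 2).index_copy_(0, B, C)

print(torch.equal(E, F))  # True
```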

Perfect. Just one suggestion: D has to be sized from C, because D does not exist yet when this line runs, so the line should read:

D = torch.zeros(C.shape[0] + mask.count_nonzero(), 2)


It seems that this line runs out of memory with large tensors:

mask = torch.all(A - B.unsqueeze(0).T, dim=0)

Try it with these sizes:

import torch

A=torch.arange(0,2000000)
B=torch.arange(0,1000000)
C=torch.rand(1000000,4)

mask = torch.all(A - B.unsqueeze(0).T, dim=0)
D = torch.zeros(C.shape[0] + mask.count_nonzero(), 4)

D[mask.logical_not(), :] = C
print(C)
print(D)
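A quick estimate of why that line fails at these sizes: the subtraction broadcasts to a matrix with one element per (B, A) pair, and `torch.arange` returns int64 by default, so each element takes 8 bytes. This sketch only covers the difference matrix itself, not any other temporaries.

```python
# Rough size of the intermediate A - B.unsqueeze(0).T for the sizes above.
len_A = 2_000_000
len_B = 1_000_000

elements = len_B * len_A   # shape (len_B, len_A) after broadcasting
bytes_needed = elements * 8  # int64 -> 8 bytes per element

print(elements)                   # 2000000000000
print(bytes_needed / 1e12, "TB")  # 16.0 TB
```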

The subtraction broadcasts to a len(B) × len(A) matrix, which is far too large at these sizes. I would suggest using the other method instead:

import torch

A=torch.arange(0,2000000)
B=torch.arange(0,1000000)
C=torch.rand(1000000,4)

D = torch.zeros(A.shape[0], 4)
D[B, :] = C

print(C)
print(D)
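If a mask is still needed, `torch.isin` (available since PyTorch 1.10) performs the same membership test without the pairwise difference matrix; for inputs this large it should not need anything close to len(A) × len(B) memory. A sketch at the same sizes, assuming B's values are unique, all occur in A, and the rows of C are in the same order as the matching values of A:

```python
import torch

A = torch.arange(0, 2000000)
B = torch.arange(0, 1000000)
C = torch.rand(1000000, 4)

# present[j] is True when A[j] appears anywhere in B.
present = torch.isin(A, B)

# Present rows receive the rows of C in order; the rest stay zero.
# Requires present.sum() == C.shape[0] (B unique and contained in A).
D = torch.zeros(A.shape[0], C.shape[1])
D[present] = C
print(D.shape)  # torch.Size([2000000, 4])
```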

It looks like I still need a way to get the indices, then, because in my actual data A can also have values missing.
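When A has gaps but is sorted in ascending order, the row position of each value of B inside A can be found with `torch.searchsorted`, and the direct-indexing approach then works unchanged. A sketch with made-up example values, assuming A is sorted and every value of B occurs in A (a value larger than max(A) would give an out-of-range position):

```python
import torch

# Hypothetical data: A is sorted but has gaps; B is a subset of A.
A = torch.tensor([3, 5, 6, 10, 11, 15])
B = torch.tensor([5, 10, 15])
C = torch.rand(3, 2)

# positions[k] is the index in A where B[k] occurs (A must be sorted).
positions = torch.searchsorted(A, B)

# Sanity check that every value of B really is in A.
assert torch.equal(A[positions], B)

D = torch.zeros(A.shape[0], C.shape[1])
D[positions] = C
print(positions.tolist())  # [1, 3, 5]
```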

I think what I'll do instead is load the B data into the A table in MySQL and pull everything back out that way. Thanks.
