erip (Elijah Rippeth)
I have a tensor with values I’d like to ignore. Sometimes these values appear in the middle of the tensor. For example, given
```
tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 3, 5]])
```
I might want to ignore 3. Is there a way to filter out the 3s and pad the shorter rows as necessary? For instance, with -1 as the padding value:
```
tensor([[1, 2, 4, 5],
        [1, 2, 5, -1]])
```
My thought was to use torch.where, but that only replaces the ignored values; it doesn't change the tensor's shape.
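To make that limitation concrete, here is a small sketch that uses torch.where with -1 as the replacement value (the same value used as padding in the desired output). The variable names are only illustrative:

```python
import torch

x = torch.tensor([[1, 2, 3, 4, 5],
                  [1, 2, 3, 3, 5]])
# torch.where takes elements from x where the condition is True
# and from the second tensor (all -1s here) where it is False
y = torch.where(x != 3, x, torch.full_like(x, -1))
# y still has shape (2, 5): the 3s are replaced by -1, not removed,
# so the remaining values are not shifted to the left
print(y)
print(y.shape)
```

Because every row of a tensor must have the same length, dropping a different number of elements per row inevitably requires shifting values and padding, which a purely element-wise op like torch.where cannot do.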
This approach needs a loop over the rows, but unfortunately I can't come up with a better approach right now:
```python
import torch

# setup
x = torch.tensor([[1, 2, 3, 4, 5],
                  [1, 2, 3, 3, 5]])
# create mask for valid values
mask = x != 3
# calculate max. number of valid elements per row (as a Python int)
nb_elements = int(mask.sum(1).max())
# create output tensor filled with the padding value, using the same dtype as x
pad_val = -1
out = torch.full((x.size(0), nb_elements), fill_value=pad_val, dtype=x.dtype)
# copy the valid values of each row to the front of the corresponding output row
for idx, (x_, m) in enumerate(zip(x, mask)):
    tmp = torch.masked_select(x_, m)
    out[idx, :len(tmp)] = tmp
print(out)
```

```
tensor([[ 1,  2,  4,  5],
        [ 1,  2,  5, -1]])
```
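If the Python loop becomes a bottleneck, one possible loop-free variant is sketched below. It is not taken from the posts above; it swaps in a different technique: a cumulative sum over the mask gives each kept element its target column, and a single advanced-indexing assignment scatters all kept values into the padded output at once. The function name filter_and_pad and its parameters are hypothetical, and the sketch assumes a non-empty 2D tensor:

```python
import torch


def filter_and_pad(x: torch.Tensor, ignore_value: int, pad_value: int = -1) -> torch.Tensor:
    """Hypothetical helper: drop ignore_value from each row of a 2D tensor,
    left-align the remaining values, and pad each row with pad_value."""
    mask = x != ignore_value
    # Target column of each kept element: running count of kept
    # elements in its row so far, minus one.
    pos = mask.long().cumsum(dim=1) - 1
    # Output width = largest number of kept elements in any row.
    max_len = int(mask.sum(dim=1).max())
    out = torch.full((x.size(0), max_len), pad_value, dtype=x.dtype, device=x.device)
    # Row index of every element, expanded to x's shape.
    rows = torch.arange(x.size(0), device=x.device).unsqueeze(1).expand_as(x)
    # Write every kept value into its (row, column) slot in one assignment;
    # boolean indexing keeps rows, pos and x aligned in row-major order.
    out[rows[mask], pos[mask]] = x[mask]
    return out


x = torch.tensor([[1, 2, 3, 4, 5],
                  [1, 2, 3, 3, 5]])
print(filter_and_pad(x, ignore_value=3))
```

For the example tensor this should produce the same result as the loop above. Since the (row, column) targets are computed up front, the work stays in a few tensor ops, which may help on larger inputs or on the GPU, though I haven't benchmarked it against the loop.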