I’m looking for the most Pythonic way to select all entries of a Tensor whose indices fulfill a certain condition. I am aware of
torch.where(), but there the condition acts on the value of each entry, not on its indices.
E.g., I would like to select all entries whose indices i, j along the first two dimensions fulfill i // k > j // k. What is the best way to achieve this?
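One way that may fit here is to build the boolean mask from index vectors and let broadcasting do the work. This is a minimal sketch, assuming the condition only involves the first two dimensions. `block_lower_mask` is a hypothetical helper name, and the sizes and `k` are made-up example values:

```python
import torch

def block_lower_mask(M: int, N: int, k: int) -> torch.Tensor:
    # Row indices as a column vector (M, 1), column indices as a row vector (1, N)
    i = torch.arange(M).unsqueeze(1)
    j = torch.arange(N).unsqueeze(0)
    # Broadcasting the comparison yields an (M, N) boolean mask
    return (i // k) > (j // k)

M, N, k = 6, 6, 2
a = torch.rand(M, N, 3)          # extra trailing dimension to show partial masking
mask = block_lower_mask(M, N, k)
selected = a[mask]               # shape (number of True entries, 3)
```

Indexing with a boolean mask that covers only the leading dimensions keeps the trailing dimensions intact, so each selected (i, j) position contributes a full row of length 3.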
This isn’t really a PyTorch-specific question, but if I were you, I’d do it this way:
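import torch
M, N = 10, 11
a = torch.rand(M, N)
# Strictly lower-triangular mask (i > j); masked_select expects a bool mask
mask = torch.tril(torch.ones(M, N), diagonal=-1).bool()
output = torch.masked_select(a, mask)  # 1-D tensor of the selected entries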
Mmmh, that doesn’t really answer my question. I’m aware that I can use where() to select entries based on some mask. And you’re right that for the special case of k = 1 I could use tril() to get a mask in very few lines of code.
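A quick sanity check, sketched under the assumption that k = 1: the index condition i // k > j // k then reduces to i > j, which should match the strictly lower-triangular mask built with `torch.tril(..., diagonal=-1)` in the reply. The variable names below are only for illustration:

```python
import torch

M, N = 10, 11
i = torch.arange(M).unsqueeze(1)  # (M, 1) row indices
j = torch.arange(N).unsqueeze(0)  # (1, N) column indices

# General index condition with k = 1, i.e. i > j
general = (i // 1) > (j // 1)

# Triangular mask: tril with diagonal=-1 keeps entries where j <= i - 1
tri = torch.tril(torch.ones(M, N), diagonal=-1).bool()

print(torch.equal(general, tri))
```

For k > 1 the triangular trick no longer applies directly, while the broadcast comparison works unchanged.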
What I’m asking, though, is whether there is a more general way to either select entries directly, or at least construct a boolean mask elegantly for an arbitrary function of the indices i, j, ... along a tensor’s dimensions.
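For an arbitrary function of the indices, one possible approach is to generate one index grid per dimension with `torch.meshgrid` and pass them to a Python function. This is a sketch, not an established PyTorch utility: `index_mask` is a hypothetical helper, it assumes `fn` returns a boolean tensor of the given shape, and the `indexing="ij"` argument assumes PyTorch 1.10 or newer:

```python
import torch
from typing import Callable, Sequence

def index_mask(shape: Sequence[int], fn: Callable[..., torch.Tensor]) -> torch.Tensor:
    # One index grid per dimension; indexing="ij" keeps (row, column, ...) order
    grids = torch.meshgrid(*(torch.arange(n) for n in shape), indexing="ij")
    # fn receives the grids as i, j, ... and must return a boolean tensor of `shape`
    return fn(*grids)

a = torch.rand(4, 5, 3)
k = 2
# Mask over the first two dimensions only
mask = index_mask(a.shape[:2], lambda i, j: (i // k) > (j // k))
selected = a[mask]  # entries whose (i, j) satisfy the condition, shape (count, 3)
```

Any elementwise expression in i, j, ... can go into the lambda (e.g. `(i + j) % 2 == 0` for a checkerboard). `meshgrid` returns expanded views, so building the grids should be cheap even for large shapes.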
Sorry, I don’t know of such a magical thing…
Better wait for somebody who can help.