@pfloat’s solution with torch.repeat_interleave
is quite interesting. However, I found something simple.
Upper triangular matrix
t = 8
tri = torch.triu( torch.ones(8,8), diagonal=1 )
tri
tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 1.],
[0., 0., 0., 0., 0., 0., 0., 0.]])
Close but if the idx tensor says 3 then index 3 must also be masked.
tri.fill_diagonal_(1)
tri
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
[0., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 1.]])
Now this upper triangular matrix can be referenced to create a mask
tri[3]
tensor([0., 0., 0., 1., 1., 1., 1., 1.])
idx = torch.tensor([4,1,2,5,3,6])
tri[idx]
tensor([[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1.]])
Since we only need 0 or 1 we can make to dtype bool or byte and this way the method can work for large values of t
.