If I have a tensor of shape (N, nB), how do I check that all N "rows" of the tensor are equal?
If you would like to check whether all rows are equal, you could compare the first row against every row (broadcasting expands x[0] to the shape of x):
(x[0]==x).all()
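A minimal sketch wrapping that comparison in a helper. The function name all_rows_equal and the sample tensors are illustrative only. Note that .all() returns a 0-dim bool tensor, so it is converted with bool() here. The comparison is exact, so for floating-point data rows that differ only by rounding will count as different, and NaN never equals NaN.

```python
import torch

def all_rows_equal(x: torch.Tensor) -> bool:
    # x[0] has shape (nB,) and broadcasts against x of shape (N, nB),
    # producing an (N, nB) boolean tensor; .all() reduces it to a 0-dim bool.
    return bool((x[0] == x).all())

same = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
diff = torch.tensor([[1, 2, 3], [1, 2, 4], [1, 2, 3]])

print(all_rows_equal(same))  # True
print(all_rows_equal(diff))  # False
```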
Alternatively, you could call torch.unique on your tensor with dim=0 and check that only one distinct row remains, i.e. that the size of dim 0 is 1:
torch.unique(x, dim=0).size(0) == 1
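A minimal sketch of the torch.unique approach, again with an illustrative helper name (all_rows_equal_unique) and sample data. With dim=0, torch.unique treats each row as a single element and returns the distinct rows, so the first dimension of the result is the number of distinct rows. It finds every distinct row rather than just testing against one, so it generally does more work than the direct comparison above.

```python
import torch

def all_rows_equal_unique(x: torch.Tensor) -> bool:
    # Collapse identical rows; if all rows match, exactly one row is left.
    return torch.unique(x, dim=0).size(0) == 1

same = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
diff = torch.tensor([[1, 2, 3], [1, 2, 4], [1, 2, 3]])

print(all_rows_equal_unique(same))           # True
print(all_rows_equal_unique(diff))           # False
print(torch.unique(diff, dim=0).size(0))     # 2 distinct rows
```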