Hi Whoab!
First, as a practical matter, looping over a small number of items (in your example, 3 * 2) won't be very expensive. If the size of the permutation (100 in your example) is relatively large (which probably means rather larger than the 100 in your example), the cost of the permutation will dwarf the cost of the loop.
So I wouldn't worry about it (unless you can show with actual timings that it matters).
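For reference, here is a minimal sketch of the loop version being discussed (the names dim1, dim2, and nPerm are chosen to match the session later in this post; the exact shape in your code may differ):

```python
import torch

# Build a (dim1, dim2, nPerm) tensor where each innermost row is an
# independent random permutation of 0 .. nPerm - 1, one randperm() per cell.
dim1, dim2, nPerm = 2, 2, 5
perms = torch.empty(dim1, dim2, nPerm, dtype=torch.int64)
for i in range(dim1):
    for j in range(dim2):
        perms[i, j] = torch.randperm(nPerm)  # each call costs O(nPerm)
```

The loop overhead here is dim1 * dim2 python iterations, which is cheap when the permutation itself is large.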
Having said that, here is a "vectorized" method that introduces a different source of computational inefficiency.
Note that randperm() should have a cost of O (n), where n is the length of the permutation. My sorting approach (below) has a cost of O (n log (n)) (because of the sort()). So for a big loop but a smallish n, it should run faster, but for large enough n, it will run more slowly.
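The core of the sorting approach, shown for a single permutation (a sketch; the batched version just adds leading dimensions):

```python
import torch

# argsort() of independent random keys yields a random permutation:
# the rank of each key within the batch is uniformly distributed
# (assuming no duplicate keys).
keys = torch.rand(5)
perm = keys.argsort()  # a random permutation of 0 .. 4, at O(n log(n)) cost
```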
Note that using randint() has a (very low) chance of generating duplicates that would (very slightly) bias the distribution of the random permutations. multinomial() (with its default of replacement = False) will not generate duplicates, but is probably more expensive.
>>> torch.__version__
'1.7.1'
>>> _ = torch.manual_seed (2021)
>>> dim1 = 2
>>> dim2 = 2
>>> nPerm = 5
>>> torch.randint (torch.iinfo (torch.int64).max, (dim1, dim2, nPerm)).argsort()
tensor([[[0, 2, 1, 3, 4],
[0, 3, 1, 2, 4]],
[[1, 4, 0, 3, 2],
[0, 3, 1, 4, 2]]])
>>> torch.multinomial (torch.ones (dim1 * dim2 * nPerm), dim1 * dim2 * nPerm).reshape ((dim1, dim2, nPerm)).argsort()
tensor([[[0, 3, 1, 2, 4],
[0, 4, 1, 2, 3]],
[[4, 1, 2, 3, 0],
[2, 4, 0, 3, 1]]])
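As a sanity check (not part of the session above), you can verify that every innermost row produced by the argsort trick is a valid permutation:

```python
import torch

torch.manual_seed(2021)
dim1, dim2, nPerm = 2, 2, 5
perms = torch.randint(torch.iinfo(torch.int64).max, (dim1, dim2, nPerm)).argsort()
# sorting each innermost row should recover 0 .. nPerm - 1 exactly
assert (perms.sort(dim=-1).values == torch.arange(nPerm)).all()
```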
Best.
K. Frank