Egregious inefficiency in torch.combinations()

Hi All!

This post adds context and details to the existing PyTorch GitHub
issue 41325 and a recent forum thread (see below).

torch.combinations() fails with an out-of-memory error on problems of
moderate size (for example, choosing 8 of 16 elements) and is vastly
slower than other implementations, such as Python’s
itertools.combinations().

One oddity (as noted in the GitHub issue) is that
torch.combinations (elements, 1) can work, whereas
the essentially equivalent computation,
torch.combinations (elements, len (elements) - 1), fails.
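To spell out why these two computations are "essentially equivalent": C(n, 1) and C(n, n - 1) are both n, and each (n - 1)-combination is just the full set with one element left out. Here is a minimal sketch using only the Python standard library (the names singles, complements and big are mine, purely for illustration):

```python
import itertools
import math

n = 16
elements = range(n)

# Both counts equal n, since C(n, 1) == C(n, n - 1).
assert math.comb(n, 1) == math.comb(n, n - 1) == n

# Leaving one element out of the full set gives an (n - 1)-combination,
# so the two result sets are in one-to-one correspondence.
singles = list(itertools.combinations(elements, 1))
complements = [tuple(e for e in elements if e != s[0]) for s in singles]
big = list(itertools.combinations(elements, n - 1))

print(len(singles), len(big))
print(sorted(complements) == sorted(big))
```

So a reasonable implementation should need about the same amount of work for k = 1 and k = n - 1.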

Consider this script:

import torch
print (torch.__version__)

import itertools

print ('pytorch:')
for k in range (1, 8):
    print ('k, C(2k, k):', k, len (torch.combinations (torch.arange (2 * k), k)))

print ('python:')
for k in range (1, 11):
    print ('k, C(2k, k):', k, len (list (itertools.combinations (range (2 * k), k))))

print ('specific cases:')

try:
    print ('trying torch.combinations (torch.arange (16), 8) ...')
    print (len (torch.combinations (torch.arange (16), 8)))
except (Exception) as e:
    print ('e:', e)

try:
    print ('trying torch.combinations (torch.arange (16), 1) ...')
    print (len (torch.combinations (torch.arange (16), 1)))
except (Exception) as e:
    print ('e:', e)

try:
    print ('trying itertools.combinations (range (16), 15) ...')
    print (len (list (itertools.combinations (range (16), 15))))
    print ('trying torch.combinations (torch.arange (16), 15) ...')
    print (len (torch.combinations (torch.arange (16), 15)))
except (Exception) as e:
    print ('e:', e)

print ('combinations should accept combinations of zero elements:')

try:
    print ('trying itertools.combinations (range (5), 0) ...')
    print (list (itertools.combinations (range (5), 0)))
    print ('trying torch.combinations (torch.arange (5), 0) ...')
    print (torch.combinations (torch.arange (5), 0))
except (Exception) as e:
    print ('e:', e)

And its output:

1.10.0
pytorch:
<string>:12: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  /opt/conda/conda-bld/pytorch_1634272128894/work/aten/src/ATen/native/TensorShape.cpp:2157.)
k, C(2k, k): 1 2
k, C(2k, k): 2 6
k, C(2k, k): 3 20
k, C(2k, k): 4 70
k, C(2k, k): 5 252
k, C(2k, k): 6 924
k, C(2k, k): 7 3432
python:
k, C(2k, k): 1 2
k, C(2k, k): 2 6
k, C(2k, k): 3 20
k, C(2k, k): 4 70
k, C(2k, k): 5 252
k, C(2k, k): 6 924
k, C(2k, k): 7 3432
k, C(2k, k): 8 12870
k, C(2k, k): 9 48620
k, C(2k, k): 10 184756
specific cases:
trying torch.combinations (torch.arange (16), 8) ...
e: [enforce fail at CPUAllocator.cpp:68] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 34359738368 bytes. Error code 12 (Cannot allocate memory)
trying torch.combinations (torch.arange (16), 1) ...
16
trying itertools.combinations (range (16), 15) ...
16
trying torch.combinations (torch.arange (16), 15) ...
e: [enforce fail at CPUAllocator.cpp:68] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 1152921504606846976 bytes. Error code 12 (Cannot allocate memory)
combinations should accept combinations of zero elements:
trying itertools.combinations (range (5), 0) ...
[()]
trying torch.combinations (torch.arange (5), 0) ...
e: Expect a positive number, but got 0
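The allocation sizes in the two error messages are worth a second look. A quick arithmetic check (standard library only; INT64_BYTES is my own name, and the interpretation is an inference from the numbers, not something I have confirmed in the source code) shows that they match n**k, not C(n, k):

```python
import math

n = 16
INT64_BYTES = 8  # torch.arange (n) produces int64 values by default

# k = 8: the failed request equals 8 bytes for each of n**k grid points.
print(n ** 8 * INT64_BYTES == 34359738368)

# k = 15: the failed request equals n**k exactly.
print(n ** 15 == 1152921504606846976)

# For comparison, the final result holds only C(n, k) * k int64 values.
for k in (8, 15):
    result_bytes = math.comb(n, k) * k * INT64_BYTES
    print('k:', k, 'result size in bytes:', result_bytes)
```

Together with the torch.meshgrid warning in the output, this suggests that an intermediate of size n**k is being materialized before the valid combinations are selected, which would explain why k = 1 works and k = 15 does not.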

I didn’t put timings into the script, but you can see by the wall clock
that torch.combinations() is dramatically slower than
itertools.combinations().

(I get essentially the same results on 1.9.0 and today’s nightly build,
1.11.0.dev20211106.)
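For anyone who wants numbers in place of the wall clock, here is a minimal timing sketch. The helper time_call and the choice n = 12, k = 6 are mine (k = 6 is within the range where torch.combinations() succeeded in the run above); the torch timing is skipped if torch is not installed. I have not included timings because they will vary by machine.

```python
import itertools
import time

try:
    import torch  # optional; the torch timing is skipped without it
except ImportError:
    torch = None


def time_call(fn, repeats=3):
    """Return the best wall-clock time in seconds over `repeats` calls of fn()."""
    best = float('inf')
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    return best


n, k = 12, 6  # C(12, 6) = 924 combinations

t_python = time_call(lambda: list(itertools.combinations(range(n), k)))
print('itertools.combinations:', t_python, 'seconds')

if torch is not None:
    t_torch = time_call(lambda: torch.combinations(torch.arange(n), k))
    print('torch.combinations:    ', t_torch, 'seconds')
```

Taking the best of a few repeats reduces noise from one-time startup costs; increase k with care, given the memory errors shown above.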

Here is the github issue:

And here is the forum thread:

Best.

K. Frank
