Fixing torch.combinations()'s performance bug

Hi All!

The computational complexity of pytorch’s algorithm for
torch.combinations() is poor enough that it should be
considered a bug and fixed.

Summary:

Pytorch’s torch.combinations() requires excessive time
and memory, even for problems of modest size. It is especially
bad for larger values of r, the number of elements in each
combination.

Python’s itertools.combinations() always outperforms
torch.combinations(), and can do so dramatically, even
on moderately-sized problems. A one-line work-around that
delegates generation of the combinations to
itertools.combinations():

torch.tensor (list (itertools.combinations (input.tolist(), r)), device = input.device)

works well, and should be used to fix torch.combinations()
unless and until an ATen (and possibly also CUDA) implementation
is developed that outperforms itertools.combinations().
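
For concreteness, here is a quick check (on an arbitrary example
input) that the work-around reproduces the native result exactly:

import itertools
import torch

inp = torch.arange (5.0)   # example input vector
r = 3
cmbNative = torch.combinations (inp, r)
cmbWorkaround = torch.tensor (list (itertools.combinations (inp.tolist(), r)), device = inp.device)
print (torch.equal (cmbNative, cmbWorkaround))   # prints: True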

A simple combinations() implementation that uses
python loops and pytorch tensor operations illustrates that
torch.combinations()'s bad computational complexity is
not inherent in the task of enumerating combinations. This
algorithm outperforms torch.combinations() for larger values
of r, but always underperforms itertools.combinations().

itertools.combinations() is packaged as a lazy iterator, rather
than returning the full set of combinations all at once. Pytorch should
consider implementing a generator version of combinations() for
those cases in which the full set of combinations is not needed, or
need not be materialized simultaneously.
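
A minimal sketch of what such a generator version might look like
(combinations_gen and its chunk_size argument are hypothetical names,
not an existing pytorch API):

import itertools
import torch

def combinations_gen (inp, r, chunk_size = 1024):
    # hypothetical generator version: yields the combinations in chunks
    # of up to chunk_size rows instead of materializing all of them
    it = itertools.combinations (range (inp.shape[0]), r)
    while  True:
        chunk = list (itertools.islice (it, chunk_size))
        if  not chunk:  break
        yield inp[torch.tensor (chunk, device = inp.device)]

total = 0
for  cmb in combinations_gen (torch.rand (24), 12, chunk_size = 4096):
    total += cmb.shape[0]   # process each chunk; all C(24, 12) rows are never held at once
print (total)   # 2704156, i.e., C(24, 12)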

(This post is a follow-up to the forum thread
Egregious inefficiency in torch.combinations()
and to github issue 41325.)

Details:

In torch.combinations (input, r), input is a 1D vector of length
n and r is the number of elements of input in each combination.

The culprit behind torch.combinations()'s poor performance is line 51
(together with the following line 52) of itertools.cpp:

std::vector<Tensor> grids = at::meshgrid(std::vector<Tensor>(r, self));

Here, grids, the full Cartesian product of r copies of input,
is materialized (using meshgrid()). grids consists of r * n**r
elements, and can be dramatically larger than the number of
elements in the full set of combinations that is returned
(namely r * C(n, r) = r * n! / (r! * (n - r)!)).
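
In python terms, the native algorithm does roughly the following (a
paraphrase of the ATen code, not the actual implementation):

import torch

def nativeCombinationsSketch (inp, r):
    # sketch of the "native" algorithm: full cartesian product plus a mask
    n = inp.shape[0]
    # materialize the full cartesian product: r tensors of n**r elements each
    grids = torch.meshgrid (*(r * [inp]), indexing = 'ij')
    igrids = torch.meshgrid (*(r * [torch.arange (n, device = inp.device)]), indexing = 'ij')
    # keep only those index tuples whose indices strictly increase
    mask = torch.ones (r * [n], dtype = torch.bool, device = inp.device)
    for  i in range (r - 1):
        mask &= igrids[i] < igrids[i + 1]
    return torch.stack ([g.masked_select (mask) for g in grids], dim = 1)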

The ratio by which the number of elements in grids exceeds
that in the actual set of combinations is, for some examples:

n = 10, r  = 1:   1.0
n = 10, r  = 5:   3.97e2
n = 10, r = 10:   1.00e10
n = 20, r = 10:   5.54e7
n = 20, r = 20:   1.05e26
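
These ratios follow directly from the element counts above, and can
be reproduced with, for example:

import math

def blowup (n, r):
    # elements in grids / elements in the returned combinations
    # = (r * n**r) / (r * C(n, r)) = n**r / C(n, r)
    return n**r / math.comb (n, r)

print (blowup (10, 5))    # 396.8..., i.e., 3.97e2
print (blowup (20, 10))   # approx. 5.54e7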

(The same performance problems and analysis apply to the
with_replacement = True case, although not as dramatically.)

The possibility of using itertools.combinations() as a
replacement for torch.combinations() is made clear in
the documentation for torch.combinations().

Quoting:

Returns

A tensor equivalent to converting all the input tensors into lists, do
itertools.combinations or itertools.combinations_with_replacement
on these lists, and finally convert the resulting list into tensor.

Loops-with-tensors and itertools.combinations() versions of
combinations(), together with timings, are given in the following
script:

import torch
print (torch.__version__)
print (torch.cuda.get_device_name())

import sys
print (sys.version)

import math
import itertools
import timeit

def combinationsA (inp, r):   # loops and tensor operations
    assert  inp.dim() == 1
    assert  r >= 0  and  r <= inp.shape[0]
    n = inp.shape[0]
    nComb = math.comb (n, r)   # exact integer C(n, r) (float division can round off for large n)
    ind = torch.empty (nComb, r, dtype = torch.int64)   # one row of indices per combination
    ind[0] = torch.arange (r)   # first combination is (0, 1, ..., r - 1)
    for  i in range (1, nComb):   # step each row to the lexicographic successor of the previous row
        ind[i] = ind[i - 1]
        l = n
        for  j in range (r - 1, -1, -1):   # find the rightmost index that can still be incremented
            ind[i, j] += 1
            if  ind[i, j] < l:  break
            l -= 1
        for  k in range (j + 1, r):  ind[i, k] = ind[i, k - 1] + 1   # reset the trailing indices
    return inp[ind]   # gather all combinations with a single tensor indexing operation

def combinationsB (inp, r):   # itertools with conversion from and to tensors
    assert  inp.dim() == 1
    assert  r >= 0  and  r <= inp.shape[0]
    return torch.tensor (list (itertools.combinations (inp.tolist(), r)), device = inp.device)

def combinationsC (inp, r):   # itertools to compute indices
    assert  inp.dim() == 1
    assert  r >= 0  and  r <= inp.shape[0]
    return inp[torch.tensor (list (itertools.combinations (range (inp.shape[0]), r)), device = inp.device)]

print ('check correctness:')
allGood = True
for  i in range (10):
    inp = torch.rand (i)
    for  j in range (1, i):
        for  device in ('cpu', 'cuda'):
            inpd = inp.to (device)   # run the check on both cpu and cuda
            cmb = torch.combinations (inpd, j)
            cmbA = combinationsA (inpd, j)
            allGood &= torch.all (cmb.eq (cmbA)).item()
            cmbB = combinationsB (inpd, j)
            allGood &= torch.all (cmb.eq (cmbB)).item()
            cmbC = combinationsC (inpd, j)
            allGood &= torch.all (cmb.eq (cmbC)).item()

print ('allGood:', allGood)

print ('timings:')
print ('   N -- pytorch "native" torch.combinations()')
print ('   A -- python loops and pytorch tensor operations')
print ('   B -- pure python itertools (with conversion from and to pytorch tensors)')
print ('   C -- python itertools to compute combination indices')
number = 10   # timeit number
repeat = 3    # timeit repeat
def timings (stmt, sync = False):   # time stmt, optionally synchronizing cuda in setup
    setup = 'torch.cuda.synchronize()' if sync else 'pass'
    return timeit.repeat (stmt, setup, globals = globals(), number = number, repeat = repeat)

cases = (   # (n, r, run "native" torch.combinations()?)
    ( 6,  1, True),
    ( 6,  3, True),
    ( 6,  6, True),
    (10,  1, True),
    (10,  5, True),
    (10,  8, True),
    (16,  8, False),   # "native" fails with OOM
    (24, 12, False),   # "native" fails with OOM
)
for  n, r, native in cases:
    note = '' if native else '  ("native" fails with OOM)'
    print ('C(%d, %d), timeit "number" = %d%s' % (n, r, number, note))
    inp = torch.rand (n)
    inpg = inp.cuda()
    print ('CPU:')
    if  native:  print ('N:', timings ('torch.combinations (inp, r).numel()'))
    print ('A:', timings ('combinationsA (inp, r).numel()'))
    print ('B:', timings ('combinationsB (inp, r).numel()'))
    print ('C:', timings ('combinationsC (inp, r).numel()'))
    print ('CUDA:')
    if  native:  print ('N:', timings ('torch.combinations (inpg, r).numel()', sync = True))
    print ('A:', timings ('combinationsA (inpg, r).numel()', sync = True))
    print ('B:', timings ('combinationsB (inpg, r).numel()', sync = True))
    print ('C:', timings ('combinationsC (inpg, r).numel()', sync = True))

Here is its output:

1.10.0
GeForce GTX 1050 Ti
3.8.3 (default, May 19 2020, 18:47:26)
[GCC 7.3.0]
check correctness:
<string>:50: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  /opt/conda/conda-bld/pytorch_1634272128894/work/aten/src/ATen/native/TensorShape.cpp:2157.)
allGood: True
timings:
   N -- pytorch "native" torch.combinations()
   A -- python loops and pytorch tensor operations
   B -- pure python itertools (with conversion from and to pytorch tensors)
   C -- python itertools to compute combination indices
C(6, 1), timeit "number" = 10
CPU:
N: [0.00039273800211958587, 0.00035531500179786235, 0.000345911001204513]
A: [0.0021452640066854656, 0.002125437997165136, 0.0021074380056234077]
B: [0.00011617399286478758, 0.00010351900709792972, 0.00010330999793950468]
C: [0.00023444498947355896, 0.00023154899827204645, 0.00022320599236991256]
CUDA:
N: [0.0014755239972146228, 0.0012697580095846206, 0.0012903540045954287]
A: [0.0025693880015751347, 0.0025139990029856563, 0.0025180049997288734]
B: [0.0005950019985903054, 0.0005792269948869944, 0.0005921980045968667]
C: [0.0006349580071400851, 0.000624655993306078, 0.0006288529984885827]
C(6, 3), timeit "number" = 10
CPU:
N: [0.000934002993744798, 0.0009012250084197149, 0.0008866400021361187]
A: [0.012340559987933375, 0.012262326999916695, 0.012249751001945697]
B: [0.00016576000780332834, 0.00015234800230246037, 0.00015243100642692298]
C: [0.00029560699476860464, 0.0002838600048562512, 0.0002834049955708906]
CUDA:
N: [0.0032089009910123423, 0.003090098995016888, 0.003078762994846329]
A: [0.012742331004119478, 0.012682663000305183, 0.012686592002864927]
B: [0.0006435590039473027, 0.0006267220014706254, 0.0006319009989965707]
C: [0.0006937019934412092, 0.000679793010931462, 0.0006800050032325089]
C(6, 6), timeit "number" = 10
CPU:
N: [0.012257819005753845, 0.012083017994882539, 0.012073186997440644]
A: [0.000352169998222962, 0.00032322999322786927, 0.0003296549984952435]
B: [0.0001320349983870983, 9.665598918218166e-05, 9.574399155098945e-05]
C: [0.00022611500753555447, 0.0002160379954148084, 0.00021456999820657074]
CUDA:
N: [0.0065253759967163205, 0.006364347005728632, 0.0063888610020512715]
A: [0.0007201859989436343, 0.0006900919979671016, 0.0006889280048198998]
B: [0.0005996480031171814, 0.000582634995225817, 0.0005704919894924387]
C: [0.0006277069915086031, 0.00061421500868164, 0.0006145149964140728]
C(10, 1), timeit "number" = 10
CPU:
N: [0.00036994199035689235, 0.0003493999975034967, 0.0003444360045250505]
A: [0.0035874680033884943, 0.003516295997542329, 0.0035193749936297536]
B: [0.00012491500820033252, 0.00011192800593562424, 0.00011186499614268541]
C: [0.00024435098748654127, 0.0002341840008739382, 0.00023348799732048064]
CUDA:
N: [0.001301341995713301, 0.0012528450024547055, 0.0012669840070884675]
A: [0.003971571990405209, 0.003967201002524234, 0.003950157988583669]
B: [0.0005974770028842613, 0.0005822969978908077, 0.0005810809961985797]
C: [0.0006445110047934577, 0.0006386099994415417, 0.0006419359997380525]
C(10, 5), timeit "number" = 10
CPU:
N: [0.012096004997147247, 0.011976242007222027, 0.011920516000827774]
A: [0.17747894099738915, 0.17712419899180532, 0.1768949889956275]
B: [0.001098584005376324, 0.0010706149914767593, 0.00108613699558191]
C: [0.0014568490005331114, 0.0014570009952876717, 0.001449190007406287]
CUDA:
N: [0.005448563999379985, 0.005311100001563318, 0.005315380010870285]
A: [0.1791406120028114, 0.17859279799449723, 0.17860569400363602]
B: [0.0015968630032148212, 0.0015729150036349893, 0.0015651840076316148]
C: [0.0018307700083823875, 0.0017917060031322762, 0.0017883199907373637]
C(10, 8), timeit "number" = 10
CPU:
N: [48.963757300996804, 48.98913558899949, 48.88807302899659]
A: [0.06408211300731637, 0.06367387500358745, 0.0637716810015263]
B: [0.0003515210119076073, 0.00032244900648947805, 0.00032212100632023066]
C: [0.0005413449980551377, 0.0005208069924265146, 0.0005347499973140657]
CUDA:
N: [1.6645929790101945, 1.5955002770060673, 1.5954844549996778]
A: [0.0643903829914052, 0.06431056399014778, 0.06436278599721845]
B: [0.0008272179984487593, 0.0007921679934952408, 0.0007913290028227493]
C: [0.000913231007871218, 0.000904972999705933, 0.0008972539944807068]
C(16, 8), timeit "number" = 10  ("native" fails with OOM)
CPU:
A: [9.492298726996523, 9.504841014000704, 9.513660684999195]
B: [0.06738795001001563, 0.06704400399758015, 0.06704581099620555]
C: [0.08578118600416929, 0.08568321299389936, 0.08573524499661289]
CUDA:
A: [9.518183541003964, 9.517515347004519, 9.520561250988976]
B: [0.06815099400409963, 0.06779349900898524, 0.06774643198878039]
C: [0.08564549899892882, 0.0854689700063318, 0.08584213600261137]
C(24, 12), timeit "number" = 10  ("native" fails with OOM)
CPU:
A: [2042.499984451002, 2045.9150624980102, 2045.8145195090037]
B: [19.31995862199983, 19.31601057900116, 19.318088586005615]
C: [25.900873358012177, 25.906365320988698, 25.903081741998903]
CUDA:
A: [2050.0636496259976, 2050.3741932809935, 2051.897835465992]
B: [19.530357536001247, 19.53001726299408, 19.53024267400906]
C: [25.719029665007838, 25.722186693994445, 25.76425094100705]

Best.

K. Frank