Why is torch.take tremendously slower than torch.index_select with two reshapes?

I have a huge model with a 2D look-up table (LUT) inside. The model itself is a large GAN, so the LUT's cost should be basically negligible.

Previously, I had it implemented like this:

flat_input = input.reshape(-1)
output = torch.index_select(lut, 0, flat_input).reshape(input.shape)

I thought that this was too many ops and introduced a cleaner version:

output = torch.take(lut, input)
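For reference, the two versions are interchangeable in terms of results. A minimal sketch (the shapes here are placeholders; the real LUT and inputs are described below) confirming they produce identical outputs:

```python
import torch

# Hypothetical small shapes for illustration only.
lut = torch.randn(38)
idx = torch.randint(0, lut.numel(), (2, 16, 8, 8))

# Version 1: flatten the indices, gather along dim 0, restore the shape.
out_select = torch.index_select(lut, 0, idx.reshape(-1)).reshape(idx.shape)

# Version 2: torch.take indexes the (flattened) input and keeps idx's shape.
out_take = torch.take(lut, idx)

assert torch.equal(out_select, out_take)
```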

And it made my model ~30% slower, whereas before this change the LUT op had a negligible performance impact.

Can somebody explain the nature of such a slowdown?

Could you post the shapes of both tensors, so that we could profile it?
Using these shapes:

x = torch.randn(100, 100, 100)
idx = torch.randint(0, x.nelement(), (1000,))

yields approximately the same timing for CPU and CUDA runs with the latest nightly binary.
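For reproducibility, a comparison along these lines can be sketched with `torch.utils.benchmark` (this is my sketch, not the original poster's code):

```python
import torch
from torch.utils import benchmark

# Shapes from the reply above.
x = torch.randn(100, 100, 100)
idx = torch.randint(0, x.nelement(), (1000,))

t_take = benchmark.Timer(
    stmt="torch.take(x, idx)",
    globals={"torch": torch, "x": x, "idx": idx},
)
t_select = benchmark.Timer(
    stmt="torch.index_select(x.reshape(-1), 0, idx)",
    globals={"torch": torch, "x": x, "idx": idx},
)

# Each timeit() returns a Measurement whose mean is in seconds.
print(t_take.timeit(1000))
print(t_select.timeit(1000))
```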

Thanks for the reply!

The input shapes grow progressively from (10, 1024, 32, 32) to (10, 1024, 512, 512), and lut has size 38.

I train with CUDA and I think that backprop is to blame here; in my experience, CPU inference is comparable as well.
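To check the backprop hypothesis, one could time forward plus backward rather than forward alone. A rough sketch under these assumptions (the LUT receives gradients, the smaller of the shapes above, CPU fallback when CUDA is unavailable; note CUDA kernels are asynchronous, so synchronization is needed around the timers):

```python
import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical reproduction of the training setup: gradients flow into
# the LUT, so the backward of the indexing op is exercised as well.
lut = torch.randn(38, device=device, requires_grad=True)
idx = torch.randint(0, 38, (10, 1024, 32, 32), device=device)

def step(fn):
    lut.grad = None
    if device == "cuda":
        torch.cuda.synchronize()  # drain pending kernels before timing
    start = time.perf_counter()
    fn(lut, idx).sum().backward()
    if device == "cuda":
        torch.cuda.synchronize()  # wait for async kernels to finish
    return time.perf_counter() - start

t_take = step(torch.take)
t_select = step(
    lambda l, i: torch.index_select(l, 0, i.reshape(-1)).reshape(i.shape)
)
print(f"take: {t_take:.4f}s  index_select: {t_select:.4f}s")
```

Comparing the two numbers on the training GPU would show whether the slowdown indeed comes from `torch.take`'s backward pass.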