s1ddok (Andrey Volodin) #1
I have a huge model with a 2D look-up table inside. The model itself is a huge GAN, so the LUT's cost should be negligible in comparison.
Previously, I had it implemented like this:
flat_input = input.reshape(-1)
output = torch.index_select(lut, 0, flat_input).reshape(input.shape)
I thought that was too many ops and introduced a cleaner version:
output = torch.take(lut, input)
And it made my model 30% slower, whereas before this change the LUT op was basically negligible in terms of performance impact.
Can somebody explain the nature of such a slowdown?
Could you post the shapes of both tensors, so that we could profile it?
Using these shapes:
x = torch.randn(100, 100, 100)
idx = torch.randint(0, x.nelement(), (1000,))
yields approximately the same timing for CPU and CUDA runs on the latest nightly binary.
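For reference, here is a minimal sketch of how one could compare the two variants at these shapes (assuming torch.utils.benchmark; this is not the exact profiling script used):

import torch
import torch.utils.benchmark as benchmark

x = torch.randn(100, 100, 100)
idx = torch.randint(0, x.nelement(), (1000,))

# Variant 1: flatten the source, gather along dim 0, restore the index shape.
t_index_select = benchmark.Timer(
    stmt="torch.index_select(x.reshape(-1), 0, idx).reshape(idx.shape)",
    globals={"x": x, "idx": idx},
)

# Variant 2: torch.take treats the source as if it were flattened to 1-D.
t_take = benchmark.Timer(
    stmt="torch.take(x, idx)",
    globals={"x": x, "idx": idx},
)

print(t_index_select.timeit(100))
print(t_take.timeit(100))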
s1ddok (Andrey Volodin) #3
Thanks for the reply!
The shapes grow progressively from (10, 1024, 32, 32) to (10, 1024, 512, 512), and the LUT is of size 38.
I train on CUDA, and I think backprop is to blame here; CPU inference is comparable in my experience as well.
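For anyone who wants to isolate the backward pass, here is a rough sketch at these shapes (using the smaller end of the reported range; the size-38 LUT is treated as 1-D, matching the snippets above; this is not the actual training code):

import torch
import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"

# LUT of 38 entries; index tensor at the smallest reported resolution
# (it reportedly grows up to (10, 1024, 512, 512) during training).
lut = torch.randn(38, device=device, requires_grad=True)
idx = torch.randint(0, 38, (10, 1024, 32, 32), device=device)

def fwd_bwd(fn):
    # Reset the accumulated gradient, then run one forward + backward step.
    lut.grad = None
    out = fn(lut, idx)
    out.sum().backward()

variants = {
    "index_select": lambda t, i: torch.index_select(t, 0, i.reshape(-1)).reshape(i.shape),
    "take": lambda t, i: torch.take(t, i),
}

for name, fn in variants.items():
    timer = benchmark.Timer(
        stmt="fwd_bwd(fn)",
        globals={"fwd_bwd": fwd_bwd, "fn": fn},
    )
    print(name, timer.timeit(20))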