I have a huge model with a 2D look-up table (LUT) inside. The model itself is a huge GAN, so the LUT's cost should be basically negligible.
Previously, I had it implemented like this:

```python
flat_input = input.reshape(-1)
output = torch.index_select(lut, 0, flat_input).reshape(input.shape)
```
I thought that was too many ops and introduced a cleaner version:

```python
output = torch.take(lut, input)
```
This change made my model 30% slower, even though before it the LUT op had a negligible performance impact.
Can somebody explain the nature of this slowdown?
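For reference, here is a minimal micro-benchmark sketch of the two paths, with hypothetical sizes (a flat 65536-entry LUT and a 512x512 index tensor, since my real shapes aren't shown above). It runs on CPU; on CUDA you would need `torch.cuda.synchronize()` around the timers for meaningful numbers. Note that `torch.take` treats the source as flattened, so the two versions agree when `lut` is 1D:

```python
import time
import torch

lut = torch.randn(65536)                   # flat look-up table (hypothetical size)
idx = torch.randint(0, 65536, (512, 512))  # 2D int64 index tensor

def via_index_select(lut, idx):
    # original version: flatten indices, gather along dim 0, restore shape
    flat = idx.reshape(-1)
    return torch.index_select(lut, 0, flat).reshape(idx.shape)

def via_take(lut, idx):
    # "cleaner" version: take indexes into the flattened source directly
    return torch.take(lut, idx)

# Both produce identical values; only the kernel path differs.
assert torch.equal(via_index_select(lut, idx), via_take(lut, idx))

def bench(fn, n=100):
    fn(lut, idx)  # warm-up
    t0 = time.perf_counter()
    for _ in range(n):
        fn(lut, idx)
    return (time.perf_counter() - t0) / n

print(f"index_select: {bench(via_index_select) * 1e6:.1f} us/call")
print(f"take:         {bench(via_take) * 1e6:.1f} us/call")
```

Timing it in isolation like this (or with `torch.utils.benchmark`) would show whether the slowdown comes from the op itself or from something else in the surrounding model.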