Strange behavior with long kernel conv1d

I’m running into a weird performance issue with the blocks that follow a long-kernel conv1d. In the example below, the simple operation of picking a random index from a batch takes about 1.44 s when the kernel is long. The conv1d itself performs fine; it’s only the subsequent block that is problematic. The experiment was run on a V100 GPU, Ubuntu 18.04, CUDA 11.1. What’s the recommended fix?

import torch
from time import time

print('Torch version {}'.format(torch.__version__))


def test_conv_timing(K):
    print('Kernel: {}'.format(K))
    x = torch.randn(1, 512, 160000, dtype=torch.float32, device='cuda')
    kernel = torch.randn(512, 1, K, dtype=torch.float32, device='cuda')
    for i in range(10):
        dummy = torch.nn.functional.conv1d(x, kernel, groups=512)
        start_time = time()
        idx = torch.randint(0, 512, (1,))
        y = x[:, idx, :]
        print('{} elapsed time {}s'.format(i, time() - start_time))


test_conv_timing(K=5)
test_conv_timing(K=20000)

Output:

Kernel: 5
0 elapsed time 0.018301963806152344s
1 elapsed time 0.0017490386962890625s
2 elapsed time 0.0017266273498535156s
3 elapsed time 0.0017251968383789062s
4 elapsed time 0.001725912094116211s
5 elapsed time 0.001722574234008789s
6 elapsed time 0.0017278194427490234s
7 elapsed time 0.0017235279083251953s
8 elapsed time 0.0017247200012207031s
9 elapsed time 0.0017256736755371094s
Kernel: 20000
0 elapsed time 1.4438841342926025s
1 elapsed time 1.4291105270385742s
2 elapsed time 1.43113374710083s
3 elapsed time 1.4308841228485107s
4 elapsed time 1.4307475090026855s
5 elapsed time 1.430490255355835s
6 elapsed time 1.4306378364562988s
7 elapsed time 1.430649995803833s
8 elapsed time 1.4352474212646484s
9 elapsed time 1.4375996589660645s
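
For reference, to back up the claim that the conv itself is fine, here is a minimal sketch of timing just the conv1d, with torch.cuda.synchronize() around the timed region so the wall-clock time covers the kernel execution rather than just the launch (the helper name is arbitrary):

import torch
from time import time


def time_conv_only(K):
    print('Kernel: {}'.format(K))
    x = torch.randn(1, 512, 160000, dtype=torch.float32, device='cuda')
    kernel = torch.randn(512, 1, K, dtype=torch.float32, device='cuda')
    for i in range(10):
        torch.cuda.synchronize()  # make sure nothing is queued before the clock starts
        start_time = time()
        y = torch.nn.functional.conv1d(x, kernel, groups=512)
        torch.cuda.synchronize()  # wait for the conv kernel to finish
        print('{} conv elapsed time {}s'.format(i, time() - start_time))


time_conv_only(K=5)
time_conv_only(K=20000)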

Just as an update: if I change the line that generates the random index so that it lives on CUDA:

idx = torch.randint(0, 512, (1,), device='cuda')

then the long kernel runs as quickly as the short one (around 5e-5 s). I believe this is the right thing to do here, but I’m still not comfortable with the difference in behavior between long and short kernels in the original code.
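
For what it’s worth, here is a sketch of the same loop timed with CUDA events (torch.cuda.Event) instead of time(), so the measurement reflects how long the indexing actually takes on the GPU regardless of what is still queued there (the function name is arbitrary):

import torch


def time_indexing_with_events(K):
    print('Kernel: {}'.format(K))
    x = torch.randn(1, 512, 160000, dtype=torch.float32, device='cuda')
    kernel = torch.randn(512, 1, K, dtype=torch.float32, device='cuda')
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for i in range(10):
        dummy = torch.nn.functional.conv1d(x, kernel, groups=512)
        start.record()  # completes in stream order, i.e. after the conv has finished
        idx = torch.randint(0, 512, (1,), device='cuda')
        y = x[:, idx, :]
        end.record()
        torch.cuda.synchronize()  # required before reading elapsed_time()
        print('{} indexing time {} ms'.format(i, start.elapsed_time(end)))


time_indexing_with_events(K=5)
time_indexing_with_events(K=20000)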

Also, I tried with torch version 1.8.1, and there even putting the randint line on CUDA still resulted in ~1.4 s for the long kernel. Can someone from the PyTorch team confirm this? Is there a workaround if for some reason I have to stick with 1.8.1? Thanks.

This difference is possibly because the data is pushed to the CPU for the indexing step, which in turn may be because the index lives on the CPU. I say this because explicitly pushing the data to the CPU before indexing gives comparable times in the two cases:

import torch
from time import time

print('Torch version {}'.format(torch.__version__))


def test_conv_timing(K):
    print('Kernel: {}'.format(K))
    x = torch.randn(1, 512, 160000, dtype=torch.float32, device='cuda')
    kernel = torch.randn(512, 1, K, dtype=torch.float32, device='cuda')
    for i in range(10):
        dummy = torch.nn.functional.conv1d(x, kernel, groups=512)
        x_cpu = x.to('cpu')  # .to() is not in-place, so keep the copy that lives on the CPU
        start_time = time()
        idx = torch.randint(0, 512, (1,))
        y = x_cpu[:, idx, :]
        # This y lives on the CPU. To get the same effect as the original code
        # we would have to push y (and x) back onto the GPU.
        print('{} elapsed time {}s'.format(i, time() - start_time))


test_conv_timing(K=5)
test_conv_timing(K=20000)

This is surprising to me, since I would expect Torch to throw an exception when it is asked to do an operation involving tensors that live on different devices. I guess the implementation does the back-and-forth transfer silently in the background here.

I haven’t checked the documentation; this may be explicitly stated somewhere.
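
As a quick sanity check (a sketch, separate from the timing question): indexing a CUDA tensor with a CPU index tensor is indeed accepted silently, whereas an elementwise op that mixes devices does raise:

import torch

x = torch.randn(1, 512, 100, device='cuda')
idx = torch.randint(0, 512, (1,))      # lives on the CPU by default

y = x[:, idx, :]                       # no exception: the transfer happens silently
print(y.device)

try:
    _ = x + torch.randn(1, 512, 100)   # mixing devices in an elementwise op
except RuntimeError as err:
    print('RuntimeError:', err)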

I totally agree that if the index is on the CPU, which it is by default, then using it to slice a CUDA tensor could cause a slowdown. But that slowdown should take a consistent amount of time, irrespective of the length of the conv1d kernel that precedes it.
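
One way to test that expectation (a sketch): time the CPU-index slice on its own, with no conv1d queued in front of it, and compare against the numbers in the original loop:

import torch
from time import time

x = torch.randn(1, 512, 160000, device='cuda')
torch.cuda.synchronize()                 # nothing queued on the GPU at this point

for i in range(10):
    start_time = time()
    idx = torch.randint(0, 512, (1,))    # CPU index, as in the original code
    y = x[:, idx, :]
    torch.cuda.synchronize()             # wait for the slice itself to finish
    print('{} slice-only elapsed time {}s'.format(i, time() - start_time))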

Also, in PyTorch 1.8.1, even placing the index on CUDA did not fix the issue. I’m just wondering whether I’m running into a weird corner case of extremely long kernels that isn’t covered by testing. The fact that it behaves differently across versions is worrisome.