torch.transpose is too slow on GPU, slower than CPU

When I write a CNN model, I find that the transpose op is too slow on the GPU, even slower than on the CPU.
Here is my test code:
test1:

import torch
import time
from torch.autograd import Variable

# time many transpose calls on a CPU tensor
x = Variable(torch.randn(100,500))
cputimes = []
for sampl in (1000, 10000, 100000, 1000000):
    start = time.time()
    for i in range(sampl):
        y = torch.transpose(x,0,1)
    end = time.time()
    cputimes.append(end-start)
print(cputimes)

# same loop after moving x to GPU 2
x = x.cuda(2)
gputimes = []
for sampl in (1000, 10000, 100000, 1000000):
    start = time.time()
    for i in range(sampl):
        y = torch.transpose(x,0,1)
    end = time.time()
    gputimes.append(end-start)
print(gputimes)

output:
[0.00479888916015625, 0.047800540924072266, 0.5636386871337891, 4.8213441371917725]
[0.0057294368743896484, 0.056331634521484375, 0.5558302402496338, 5.78531289100647]

test2:

In [16]: torch.cuda.set_device(2)

In [17]: %timeit torch.transpose(torch.FloatTensor(20,100),1,0)
The slowest run took 26.26 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 1.72 µs per loop

In [18]: %timeit torch.transpose(torch.cuda.FloatTensor(20,100),1,0)
The slowest run took 21.21 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.05 µs per loop

It causes my model to run very slowly. How can I solve it?

Hi,

In your example, you could replace the transpose function with any function in torch and you would get the same behavior.
The transpose operation does not actually touch the tensor data; it only works on the metadata (sizes and strides). The code that does this is exactly the same for CPU and GPU tensors and never touches the GPU. The runtimes that you see in your test are just the overhead of the Python loop plus the call into C code (in your case the C code does almost nothing).
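A quick way to check that transpose only changes metadata is to compare the data pointer and the strides of the input and the output. A minimal sketch on a CPU tensor (transpose is a view for CUDA tensors as well):

```python
import torch

x = torch.randn(100, 500)
y = torch.transpose(x, 0, 1)

# transpose returns a view: same underlying memory, no data copied
print(y.data_ptr() == x.data_ptr())   # True
# only the metadata changes: shape and strides are swapped
print(x.shape, x.stride())            # torch.Size([100, 500]) (500, 1)
print(y.shape, y.stride())            # torch.Size([500, 100]) (1, 500)
# the view is not contiguous; .contiguous() is the step that actually copies data
print(y.is_contiguous())              # False
z = y.contiguous()
print(z.data_ptr() == x.data_ptr())   # False
```

So the cost of the transpose call itself is essentially the fixed per-call overhead, regardless of the tensor's size or device.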
The GPU version is slightly slower because the CUDA library has to fetch its state before calling the functions, an extra step that the pure CPU version skips.

This code sample is slow only because of the Python loop that calls C functions. To make it faster, you need to find a way to remove this loop. If, in your case, you want to transpose a bunch of matrices, you could for example stack them into a single tensor and then call transpose once on that tensor.
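A minimal sketch of the stacking idea, assuming a hypothetical workload of 64 separate 20x100 matrices. Note that torch.stack copies the data once, so this mainly pays off when the overhead of many tiny per-matrix calls dominates:

```python
import torch

# Hypothetical workload: 64 separate 20x100 matrices
mats = [torch.randn(20, 100) for _ in range(64)]

# Slow pattern: one Python-level call (and its overhead) per matrix
looped = [m.t() for m in mats]

# Batched pattern: stack once into a (64, 20, 100) tensor, then one call
batch = torch.stack(mats)            # shape (64, 20, 100)
batched = batch.transpose(1, 2)      # shape (64, 100, 20)

# Both give the same matrices
print(all(torch.equal(a, b) for a, b in zip(looped, batched)))  # True
```

Keeping the data in one batched tensor also lets later steps use batched ops instead of another Python loop.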


To add to what @albanD wrote, here’s also a quick demo of Autograd overhead.

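import time

import torch


class NoOp(torch.autograd.Function):
    # custom Functions use static forward/backward and are called via .apply()
    @staticmethod
    def forward(ctx, x):
        # view_as returns a new tensor object sharing x's data
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def print_times(x, func, msg):
    # time one million calls of func(x): almost pure call overhead
    start = time.time()
    for i in range(1000000):
        _ = func(x)
    t = time.time() - start
    print("{}: {:.5f}".format(msg, t))

tensor   = torch.randn(100, 500)
ndarray  = tensor.numpy()
variable = tensor.clone().requires_grad_()  # tensor tracked by autograd

print_times(tensor,   lambda x: x,             "Python noop")
print_times(ndarray,  lambda x: x.transpose(), "Numpy transpose")
print_times(tensor,   lambda x: x.t(),         "Torch transpose")
print_times(variable, lambda x: NoOp.apply(x), "Autograd noop")
print_times(variable, lambda x: x.t(),         "Autograd transpose")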

# output:
#
# Python noop: 0.07554
# Numpy transpose: 0.23783
# Torch transpose: 0.49813
# Autograd noop: 1.95098
# Autograd transpose: 3.72835

Actually, I changed my code and removed the transpose op, but the model still runs slower on the GPU than on the CPU. Here is some of my model code:

x = x.unsqueeze(1)
x0 = F.relu(self.conv0_0(x)).squeeze(3)
x1 = F.relu(self.conv0_1(x)).squeeze(3)
x = torch.cat((x0,x1),1)

x = x.unsqueeze(1)
x0 = F.relu(self.conv1_0(x)).squeeze(2)
x1 = F.relu(self.conv1_1(x)).squeeze(2)

x0 = F.max_pool1d(x0, x0.size(2)).squeeze(2)
x1 = F.max_pool1d(x1, x1.size(2)).squeeze(2)
x = torch.cat((x0,x1),1)

The self.conv* layers are torch.nn.Conv2d.
Why does it perform badly on the GPU?
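The fragment above is not runnable on its own. A hypothetical module that wraps it makes the tensor shapes explicit; the class name and every layer size here are assumptions, chosen only so that the squeeze and cat calls line up, not the poster's actual configuration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoStageCNN(nn.Module):
    """Hypothetical wrapper around the fragment above.

    Assumed input shape: (batch, seq_len, emb_dim); c1 and c2 are guesses.
    """

    def __init__(self, emb_dim=50, c1=16, c2=32):
        super().__init__()
        # kernels span the full embedding width, so the last dim becomes 1
        self.conv0_0 = nn.Conv2d(1, c1, (3, emb_dim))
        self.conv0_1 = nn.Conv2d(1, c1, (3, emb_dim))
        # kernels span all 2*c1 stacked channels, so dim 2 becomes 1
        self.conv1_0 = nn.Conv2d(1, c2, (2 * c1, 3))
        self.conv1_1 = nn.Conv2d(1, c2, (2 * c1, 5))

    def forward(self, x):                            # (N, L, D)
        x = x.unsqueeze(1)                           # (N, 1, L, D)
        x0 = F.relu(self.conv0_0(x)).squeeze(3)      # (N, c1, L-2)
        x1 = F.relu(self.conv0_1(x)).squeeze(3)      # (N, c1, L-2)
        x = torch.cat((x0, x1), 1)                   # (N, 2*c1, L-2)

        x = x.unsqueeze(1)                           # (N, 1, 2*c1, L-2)
        x0 = F.relu(self.conv1_0(x)).squeeze(2)      # (N, c2, L-4)
        x1 = F.relu(self.conv1_1(x)).squeeze(2)      # (N, c2, L-6)

        # global max pool over the remaining length
        x0 = F.max_pool1d(x0, x0.size(2)).squeeze(2)  # (N, c2)
        x1 = F.max_pool1d(x1, x1.size(2)).squeeze(2)  # (N, c2)
        return torch.cat((x0, x1), 1)                 # (N, 2*c2)


model = TwoStageCNN()
out = model(torch.randn(4, 20, 50))
print(out.shape)   # torch.Size([4, 64])
```

With sizes in this range, each forward pass is only a handful of small convolutions.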

The main reason is usually that you are working with such small conv layers and so little data that the overhead of launching the work on the GPU is higher than the computation itself.
If you have a very small net with small inputs, you will see no speedup from using GPUs.
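One way to see this is to time the same layer on both devices at a tiny and a moderate input size. CUDA kernels run asynchronously, so torch.cuda.synchronize() is needed before reading the clock, and a warm-up call keeps one-time setup out of the measurement. A rough sketch: the helper name time_conv, the single 3x3 conv, the sizes, and the iteration count are arbitrary assumptions, and the numbers will depend on your hardware. The GPU is only timed when one is available:

```python
import time

import torch
import torch.nn.functional as F


def time_conv(device, batch, size, iters=10):
    """Average seconds per forward pass of one 3x3 conv on `device`."""
    # Hypothetical workload: 16 input and 16 output channels
    weight = torch.randn(16, 16, 3, 3, device=device)
    x = torch.randn(batch, 16, size, size, device=device)
    F.conv2d(x, weight)  # warm-up: keep one-time setup costs out of the timing
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        F.conv2d(x, weight)
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait until all timed kernels have finished
    return (time.perf_counter() - start) / iters


devices = [torch.device("cpu")]
if torch.cuda.is_available():
    devices.append(torch.device("cuda"))

results = {}
for batch, size in [(1, 8), (32, 64)]:
    for dev in devices:
        t = time_conv(dev, batch, size)
        results[(dev.type, batch, size)] = t
        print(f"{dev.type:4s} batch={batch:2d} size={size:2d}: {t * 1e3:.3f} ms/iter")
```

For the tiny case the per-call launch overhead tends to dominate on the GPU, while the larger case gives the GPU enough work per kernel for its throughput to matter.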