Huge GPU performance difference between NVIDIA RTX 2070 and GTX 1080 GPUs

Even though these two GPUs are fairly close in compute specifications (the GTX 1080 being slightly better), training my model on a GTX 1080 machine takes 0.5 seconds of GPU processing time per batch, whereas an RTX 2070 machine takes 9 seconds on average for the same operation. This happens for a variety of models I have trained, including pure CNN, CNN+LSTM, and CNN+RNN. I am also using cudnn.benchmark = True and pin_memory = True on both machines. Both machines have roughly the same CPU and memory capabilities. Any thoughts as to why the RTX machine could take 18x longer for the same operation?

Did you make sure to synchronize the code properly before starting and stopping the timer?
Since CUDA operations are asynchronous, you’ll get wrong profiling results without synchronizations.
Also, are you using the same PyTorch, CUDA and cudnn versions?
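A minimal timing pattern would look something like this (just a sketch; the model and input are placeholders):

import time
import torch

model = torch.nn.Linear(1024, 1024).cuda()   # placeholder model
x = torch.randn(64, 1024, device='cuda')     # placeholder input

torch.cuda.synchronize()                     # wait for all pending GPU work
start = time.time()
output = model(x)
loss = output.sum()
loss.backward()
torch.cuda.synchronize()                     # wait for the backward kernels to finish
print('iteration took {:.3f}s'.format(time.time() - start))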

Interesting, after investigating the torch versions, it turns out that the GTX machine had torch version == 0.4.1 and the RTX had version == 1.4.0. When I upgraded the GTX machine to 1.4.0, I was able to reproduce the bad performance of the RTX.

Why is that the case?

In terms of synchronizing the timers, I’m just measuring the time it takes to complete a forward iteration for a batch, and also the time it takes the dataloaders to get the data; the difference between the two is the time spent on the forward pass. In any case, upgrading the PyTorch version makes my epochs go from 5 minutes to 80 minutes.

Actually, I did some more analysis of the timing. It turns out that most of the time in the routine is spent in the loss.backward() call, not in the forward pass nor in loading the data with .cuda(). Here is a snippet of the code inside my training loop.

for data in dataloader:
    # move the batch to the GPU
    data = data.cuda()
    # make the input require gradients
    data.requires_grad_(True)
    # forward pass and loss
    output = model(data)
    loss = criterion(output, labels)
    # accuracy bookkeeping (detached, so it doesn't build a graph)
    predictions = (output.detach() > 0).int()
    accuracy = (labels.detach().int() == predictions).sum().item() / batch_size
    # backward pass and parameter update
    optimizer.zero_grad()
    loss.backward() ### backward takes 100x longer than forward in torch 1.4.0
    optimizer.step()

Again, if you don’t synchronize the code via torch.cuda.synchronize(), the next “blocking” operation will accumulate the time from all asynchronous calls.
Please synchronize the code before starting and stopping the timers and feel free to post the results.

Sorry if I didn’t make it clear: I am using the synchronize functionality, I just didn’t want the code to look like spaghetti. I insert it between every line shown, as follows. I also use an Averager object to track the steady-state average of each time measure.
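For reference, Averager is just a small running-average helper, roughly along these lines (my actual implementation may differ slightly):

class Averager(object):
    """Accumulates values and exposes their running mean via .avg."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count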

#Initialize averager objects to accumulate time measures
batch_time = Averager()
data_time = Averager()
cuda_time = Averager()
grad_time = Averager()
forward_time = Averager()
loss_time = Averager()
prediction_time = Averager()
accuracy_time = Averager()
optimizer_time = Averager()
backward_time = Averager()
#Initialize the per-iteration timer
end_time = time.time()
for data in dataloader:
    #data loading
    data_time.update(time.time() - end_time)
    #load gpu
    data = data.cuda(non_blocking=False)
    torch.cuda.synchronize()
    cuda_time.update(time.time() - end_time)
    #set gradient
    data.requires_grad_(True)
    torch.cuda.synchronize()
    grad_time.update(time.time() - end_time)
    #forward call
    output = model(data)
    torch.cuda.synchronize()
    forward_time.update(time.time() - end_time)
    #compute loss
    loss = criterion(output, labels)
    torch.cuda.synchronize()
    loss_time.update(time.time() - end_time)
    #make predictions
    predictions = (output.detach() > 0).int()
    torch.cuda.synchronize()
    prediction_time.update(time.time() - end_time)
    #calculate accuracy
    accuracy = (labels.detach().int() == predictions).sum().item()/batch_size
    torch.cuda.synchronize()
    accuracy_time.update(time.time() - end_time)
    #optimize
    optimizer.zero_grad()
    torch.cuda.synchronize()
    optimizer_time.update(time.time() - end_time)
    loss.backward() ### backward takes 100x longer than forward in torch 1.4.0
    torch.cuda.synchronize()
    backward_time.update(time.time() - end_time)
    optimizer.step()
    torch.cuda.synchronize()
    batch_time.update(time.time() - end_time)
    #restart the timer
    end_time = time.time()

#Then print the average measures for the epoch
print(
    'Time ({batch_time.avg:.3f})\t'
    'Backward ({backward_time.avg:.3f})\t'
    'Optimizer ({optimizer_time.avg:.3f})\t'
    'Accuracy ({accuracy_time.avg:.3f})\t'
    'Loss ({loss_time.avg:.3f})\t'
    'Forward ({forward_time.avg:.3f})\t'
    'Grad ({grad_time.avg:.3f})\t'
    'Cuda ({cuda_time.avg:.3f})\t'
    'Data ({data_time.avg:.3f})'.format(
        batch_time=batch_time,
        backward_time=backward_time,
        optimizer_time=optimizer_time,
        accuracy_time=accuracy_time,
        loss_time=loss_time,
        forward_time=forward_time,
        grad_time=grad_time,
        cuda_time=cuda_time,
        data_time=data_time))

These are my recorded measures when running on torch 1.4.0 (note that each measure is cumulative from the start of the iteration, since end_time is only reset at the end of the loop):
Time (16.885), Backward (16.885), Optimizer (0.476), Accuracy (0.476), Loss (0.476), Forward (0.457), Grad (0.348), Cuda (0.348), Data (0.277)

Then when run on the same machine using torch 0.4.1:
Time (0.524) , Backward (0.520), Optimizer (0.510), Accuracy (0.510), Loss (0.508), Forward (0.419565), Grad (0.405319), Cuda (0.405), Data (0.229)

To add some more information about the system:
GTX 1080 GPU, CUDA 10.2, NVCC 9.2. I also get similar benchmarks when running on an
RTX 2070 GPU, CUDA 10.2, NVCC 10.2.

Thanks for the code. Could you post your model and the input shapes, so that we can take a look at why the backward pass is 100x slower in 1.4.0?

Hey @ptrblck, I’m somewhat conflicted about publicly sharing my source code. Is there a way I could share it only with PyTorch developers? I want to contribute to the community, but would not like to expose intellectual property.

Inspecting my code, something that might be affecting the computation of the gradients is that my model’s forward method contains conditional statements, assertions, and try/except blocks. I’m guessing some of these might be impacting the performance.

That is understandable.
Would it be possible to “simulate” your approach, e.g. by adding some conditions using just random numbers?

I don’t really know how, and the model is ~300 lines of code. I do plan to release the code with the publication. Shall we put this on hold until then?

Sure. Make sure you publish the paper and ping me here again once the code is public. :wink:

PS: Not sure if you’re already using it, but add torch.backends.cudnn.benchmark = True at the beginning of your script. This will run benchmarks for each new input shape and select the fastest cudnn kernel for your workload. Note that the first iteration will be slower due to the benchmarking (and so will each “first” iteration for a new input shape).
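E.g. (sketch; MyModel just stands in for your actual model):

import torch

torch.backends.cudnn.benchmark = True   # enable the cudnn autotuner once, at startup

model = MyModel().cuda()                # placeholder for your actual model
# ... build the dataloader, criterion, optimizer and run the training loop as before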

Hi @ptrblck,
Yes, I’m using torch.backends.cudnn.benchmark = True, and I also replaced all the try/except blocks with conditionals; that made no difference in performance.

Another piece of information that might be useful for tracking down the performance sink: the model I’m using is a transfer-learning model whose weights were trained on PyTorch 0.4.1. Could there be a performance issue with using weights trained with a previous PyTorch version?

I still haven’t been able to publicly release the code, but I can grant you read access. Do you think that might help resolve this issue?

That shouldn’t be the case, as the state_dict would only hold the parameters and no operations etc.
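You can check this quickly with a stand-in module (just a sketch):

import torch.nn as nn

m = nn.Conv2d(3, 16, kernel_size=3)
for name, tensor in m.state_dict().items():
    print(name, type(tensor), tensor.shape)   # only parameter/buffer tensors, no operations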

No, please don’t share the code with “strangers” without a proper agreement. :wink:
If you could write a “fake model” with a similar architecture without releasing any of your research, feel free to post it here. Otherwise, I’ll just wait until you are able to release the code to have a look at it.

Hi, @ptrblck. I actually figured out a way to share the model without disclosing my custom architecture. I’m basing my model on the ITracker model (https://github.com/CSAILVision/GazeCapture/blob/master/pytorch/ITrackerModel.py), which is publicly available in the repo https://github.com/CSAILVision/GazeCapture/tree/master/pytorch. This model suffers from the same issue as mine.

I was able to pinpoint that the change which introduced the performance issue landed in torch version 1.3.0.

Hi @ptrblck, I understand if this is not a pressing issue for the development team right now, but I wonder if there has been any progress in understanding the root cause and if I can help somehow. :slight_smile:

@ptrblck I think that the issue might be caused by using the layer:

nn.CrossMapLRN2d

Thanks for the isolation.
What were the performance numbers for this module before the 1.3 release and after, as you’ve mentioned that the regression was introduced there?
Since this layer doesn’t use any 3rd party libs such as cudnn, native PyTorch functions seem to have introduced the regression (source code).

Sorry, what is it that you are asking? Version 1.3.0 introduced the performance regression; versions such as 1.2.0 and earlier have good performance for the nn.CrossMapLRN2d module.

I was asking for your profiling results, i.e. since you already isolated this layer as introducing the regression, what was the performance in 1.2 vs. 1.3? How did you profile it, and which device are you using?
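I.e. something along these lines, run once in 1.2 and once in 1.3 (just a sketch with arbitrary shapes):

import time
import torch
import torch.nn as nn

lrn = nn.CrossMapLRN2d(size=5, alpha=1e-4, beta=0.75, k=1.0).cuda()
x = torch.randn(64, 96, 55, 55, device='cuda', requires_grad=True)

n_iters = 50
torch.cuda.synchronize()
t0 = time.time()
for _ in range(n_iters):
    out = lrn(x)                  # forward only
torch.cuda.synchronize()
t1 = time.time()
for _ in range(n_iters):
    out = lrn(x)
    out.sum().backward()          # forward + backward
torch.cuda.synchronize()
t2 = time.time()
print('forward: {:.3f}s, forward+backward: {:.3f}s over {} iterations'.format(
    t1 - t0, t2 - t1, n_iters))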