The cpu() call is slow only for a model's output tensor

I have a model whose output tensor has shape [32, 2, 128, 128]. Copying this tensor to CPU memory takes about 63 milliseconds.

When I create a tensor of the same shape [32, 2, 128, 128] with torch.randn() on the same GPU, copying it to CPU memory takes only about 1 millisecond.

What options are there to speed up cpu() for a model’s output tensor? I have already tried various things, such as torch.no_grad(). I also tried copying the model’s output tensor to a different, unused GPU, but calling cpu() from that GPU is just as slow. It looks like some expensive parsing is going on for a model’s output tensor.

Does anyone have more insights, or ideas on how I could improve this?


How are you measuring the time here? If you are not calling torch.cuda.synchronize() before the output tensor copy, your measurement could unwittingly include the cost of earlier CUDA operations that have been dispatched but not yet completed (e.g., running the model, if that is what happens just before the copy).
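One way to keep queued GPU work out of the measurement is a small timing helper that synchronizes before starting the clock and again before stopping it. This is a minimal sketch; `time_call` is a hypothetical helper, and the `sync` callable is a parameter (defaulting to a no-op) so the helper itself has no PyTorch dependency. With PyTorch you would pass `torch.cuda.synchronize`.

```python
import time
from typing import Callable, Tuple, TypeVar

R = TypeVar("R")


def time_call(fn: Callable[[], R], sync: Callable[[], None] = lambda: None) -> Tuple[R, float]:
    """Run fn() and return (result, elapsed milliseconds).

    sync() is called before the clock starts, so work queued earlier
    (e.g. a model's forward pass) is not billed to fn, and again before
    the clock stops, so fn's own device work is included.
    """
    sync()  # drain previously dispatched, still-running device work
    start = time.perf_counter()
    result = fn()
    sync()  # wait for anything fn itself queued on the device
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    return result, elapsed_ms
```

For example, `out_cpu, ms = time_call(lambda: output.cpu(), sync=torch.cuda.synchronize)`, where `output` stands for the model's output tensor. If the copy is still slow when timed this way, the pending model computation can be ruled out as the cause.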

Yes, I also call torch.cuda.synchronize() before I start the timer. And there are no CUDA operations or models on the other GPU to which I copied the model’s output tensor. That test showed that the cpu() call is slow from this GPU as well.

That’s interesting, could you post a minimal code snippet that reproduces this issue?

It turned out that the first cpu() call after each batch inference is always slow, in my case often around 150 ms. This is the case even if I first move the model’s output tensor to another, unused GPU and then to CPU memory.

My current solution is to store the model’s output tensor for each batch in a list. Once all inferences are done, I loop through the list and call cpu() on each batch. The first cpu() call after the last inference is still slow, but the remaining cpu() calls now take between 1 and 5 ms. This significantly improved the total processing time for my images (the total inference time did not change).
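The workaround described above can be sketched as a two-phase loop: run inference on every batch first, then copy the collected outputs to CPU memory. This is a minimal sketch, not the poster's actual code. `infer_then_transfer`, `model`, and `batches` are hypothetical names, and `to_host` defaults to calling `.cpu()` on each output, but it can be swapped for any callable so the structure can be exercised without a GPU.

```python
from typing import Any, Callable, Iterable, List


def infer_then_transfer(
    model: Callable[[Any], Any],
    batches: Iterable[Any],
    to_host: Callable[[Any], Any] = lambda t: t.cpu(),
) -> List[Any]:
    """Run model on all batches, then move every output to the host."""
    # Phase 1: inference only; outputs stay on the GPU.
    gpu_outputs = [model(batch) for batch in batches]
    # Phase 2: copy each stored output to CPU memory, in batch order.
    return [to_host(out) for out in gpu_outputs]
```

In real use this would typically be called inside a `with torch.no_grad():` block. The trade-off is GPU memory: every output stays resident until phase 2. Assuming float32 outputs, each [32, 2, 128, 128] tensor holds 1,048,576 elements, or 4 MiB, so the list grows by that much per batch.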