Model inference time depends on whether the input numpy array is contiguous

We ran into a performance problem in our code today. After some debugging, it turned out to be related to the numpy array passed to PyTorch's torch.as_tensor function. If we make a copy of the numpy array before passing it in, the code runs much faster than with the original numpy array.

Here is a short Python script that shows the behavior:

import torch
import numpy as np
from time import time


for i in range(100):
    # create image (numpy array)
    s = 1000
    x = np.zeros([s, s, 3])

    # crop image
    x = x[:, 1:-1]

    # create batch
    x = x[None].transpose(0, 3, 1, 2)

    # every 2nd iteration, create copy of numpy array
    create_copy = bool(i % 2)
    if create_copy:
        x = x.copy()

    # copy to gpu, apply some op, copy back
    t0 = time()
    x = torch.as_tensor(x).to('cuda')
    x = x+1
    x = x.to('cpu')
    t1 = time()

    print(f"{'copy' if create_copy else 'no copy'}: {1000*(t1-t0):.1f}ms")

The output looks something like this; the iterations with the copy are much faster:

...
no copy: 23.9ms
copy: 13.9ms
...

While creating the copy seems to solve the issue, I would still like to know what is going on here. I could not locate the problem. Is it due to the view created by the slice operator? Or due to the non-contiguous memory?

I think you would see the same performance if you started the timer before the manual x.copy() in every second iteration.
If I’m not mistaken, torch.as_tensor(x).to('cuda') triggers a copy internally, as it needs to create a clone of the underlying numpy array data and push it to the device. Since your manual copy() already performed this copy before the timer was started, it looks as if it came for “free”.

After moving t0 right before the if create_copy condition, I get the same iteration speed.
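For reference, here is a minimal sketch of the first script with the timer moved so that the optional numpy copy is part of the measured region. The function name timed_iteration and the CPU fallback are my own additions for illustration and are not part of the original code.

```python
import time

import numpy as np
import torch

# Assumption: fall back to the CPU so the sketch also runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def timed_iteration(create_copy, size=1000):
    """Hypothetical helper: time one iteration, including the optional copy."""
    x = np.zeros([size, size, 3])
    x = x[:, 1:-1]                     # crop image (creates a view)
    x = x[None].transpose(0, 3, 1, 2)  # create batch (still a view)

    if device == "cuda":
        torch.cuda.synchronize()       # make sure no earlier GPU work is pending
    t0 = time.perf_counter()           # timer now starts BEFORE the copy
    if create_copy:
        x = x.copy()
    t = torch.as_tensor(x).to(device)
    t = t + 1
    t = t.to("cpu")                    # copying back waits for the GPU work
    t1 = time.perf_counter()
    return 1000 * (t1 - t0)


for i in range(6):
    create_copy = bool(i % 2)
    ms = timed_iteration(create_copy)
    print(f"{'copy' if create_copy else 'no copy'}: {ms:.1f}ms")
```

With the copy inside the timed region, both branches pay for making the data contiguous somewhere, which is why the two timings should end up close to each other.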

Hi,

thanks a lot for your reply.
You’re right - it seems I simplified the code too much and thought it still captures the issue, but it does not.

So let me try again. I have to add a bit more code to reproduce it: a sequential model with 3 conv ops with dilation and padding. I hope the way I time the code is OK (I do a CUDA sync before querying the time).

import torch
import torch.nn as nn
import numpy as np
from time import time


model = nn.Sequential(nn.Conv2d(in_channels=3,
                                out_channels=64,
                                kernel_size=3,
                                padding=2,
                                dilation=2),
                      nn.Conv2d(in_channels=64,
                                out_channels=128,
                                kernel_size=3,
                                padding=4,
                                dilation=4),
                      nn.Conv2d(in_channels=128,
                                out_channels=128,
                                kernel_size=3,
                                padding=2,
                                dilation=2))
model.eval()
model.to('cuda')

for i in range(10):
    torch.cuda.synchronize()
    t0 = time()

    # create image (numpy array)
    x = np.zeros([720, 1280, 3], dtype=np.uint8)

    # cut out part of the image
    x = x[1:-1, 1:-1]
    x = x.copy()  # if I remove this line, every iteration is fast

    # create batch
    x = x[None].transpose(0, 3, 1, 2)

    # copy at each 2nd iteration
    make_copy = bool(i % 2)
    print('copy' if make_copy else 'no copy')
    if make_copy:
        x = x.copy()

    x = torch.as_tensor(x).to('cuda')
    x = x.float()
    x = x/255-0.5

    print('is_contiguous:', x.is_contiguous())

    with torch.no_grad():
        y = model(x)

    torch.cuda.synchronize()
    t1 = time()

    print('dt:', (t1-t0)*1000)
    print()

This gives me:

no copy
is_contiguous: False
dt: 2819.4923400878906

copy
is_contiguous: True
dt: 78.82118225097656

So there is a huge gap on my GPU between copying and not copying. What is interesting is that if the Torch tensor is contiguous everything runs fast, and otherwise it runs slow (so as_tensor creates a non-contiguous tensor from a non-contiguous numpy array!?). I was not able to reproduce it with only 2 stacked conv ops, but with 3 conv ops and a large enough number of channels I could (e.g. with all output channels set to 64, I could not reproduce it).

Tested on a 2080 Ti and a 3090, with PyTorch 1.8.1+cu111.

P.S.: I changed the title, as it is not the as_tensor function that is slow, but somehow the model inference.
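To make the contiguity question concrete, here is a small CPU-only sketch (the variable names are mine). It illustrates that torch.as_tensor shares memory with the numpy array and therefore keeps its strides, so a transposed view stays non-contiguous. It also checks the layout against torch.channels_last, which is what an HWC image viewed as NCHW looks like. Whether that layout is the reason for the slowdown is an assumption that nothing in this thread confirms.

```python
import numpy as np
import torch

img = np.zeros([720, 1280, 3], dtype=np.uint8)
crop = img[1:-1, 1:-1].copy()             # contiguous HWC array
batch = crop[None].transpose(0, 3, 1, 2)  # NCHW view onto the HWC memory

t_view = torch.as_tensor(batch)           # no copy: strides are preserved
t_copy = torch.as_tensor(batch.copy())    # numpy copy is C-contiguous

print(batch.flags["C_CONTIGUOUS"])                # False
print(np.shares_memory(t_view.numpy(), crop))     # True: same buffer
print(t_view.is_contiguous())                     # False
print(t_copy.is_contiguous())                     # True
# The transposed view has exactly the strides of a channels_last tensor.
print(t_view.is_contiguous(memory_format=torch.channels_last))  # True
```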

Thanks for the update. It seems you are now trying to profile the GPU execution time for the posted model (not the data transfer). Since CUDA operations are executed asynchronously, you would have to synchronize the code before starting and stopping the timer via torch.cuda.synchronize(). Did you add these syncs already or are you reusing the previous code?

Yes, I'm measuring the whole model execution time. In my first post I thought the data transfer was slow, but it seems to be the model execution. I’m using the code from post 3. So yes, I do a sync before measuring t0 as follows (and the same for t1):

    torch.cuda.synchronize()
    t0 = time()

Hi,

@ptrblck any update on this? Can you reproduce? If you need more information please let me know.

push.

Even though we bypassed the problem by making sure the numpy array is contiguous before converting it to a tensor, it would still be interesting to know what causes the drop in performance.
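For anyone who wants the same workaround, this is roughly what it looks like. The helper name to_batch_tensor is hypothetical, and using np.ascontiguousarray instead of an unconditional .copy() is my choice here; it only copies when the array is not already C-contiguous.

```python
import numpy as np
import torch


def to_batch_tensor(img_hwc):
    """Hypothetical helper: HWC numpy image -> contiguous NCHW tensor."""
    batch = img_hwc[None].transpose(0, 3, 1, 2)  # NCHW view, non-contiguous
    batch = np.ascontiguousarray(batch)          # copies only if needed
    return torch.as_tensor(batch)


# Same cropped image as in the reproduction script above.
img = np.zeros([720, 1280, 3], dtype=np.uint8)[1:-1, 1:-1]
x = to_batch_tensor(img)
print(x.shape, x.is_contiguous())
```

The same effect can be had on the PyTorch side by calling .contiguous() on the tensor after the conversion.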

Sorry for not checking the updated code. Let me take another look at it.

EDIT: Unfortunately, I cannot reproduce the large difference in speed using your code snippet and get:

copy
is_contiguous: True
dt: 28.643369674682617

no copy
is_contiguous: False
dt: 27.836084365844727

copy
is_contiguous: True
dt: 37.908315658569336

no copy
is_contiguous: False
dt: 28.136253356933594

copy
is_contiguous: True
dt: 28.306961059570312

no copy
is_contiguous: False
dt: 27.74810791015625

copy
is_contiguous: True
dt: 29.030799865722656

no copy
is_contiguous: False
dt: 28.918027877807617

copy
is_contiguous: True
dt: 29.46949005126953

Note that the very first CUDA operation initializes the CUDA context and should thus be removed from the profile. Warmup iterations are usually needed as well, as your GPU might be in an idle state, and even after the context has been initialized the first op might take a bit more time.
Could you add more iterations and check whether the numbers stabilize?
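A minimal sketch of such a measurement, with warmup iterations excluded and a synchronization before the timer is started and stopped, could look like this. The names benchmark, sync and forward are made up for this example, the small single-layer model is only a stand-in, and the CPU fallback is there so the snippet runs without a GPU.

```python
import time

import torch
import torch.nn as nn

# Assumption: fall back to the CPU so the sketch also runs without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def sync():
    """Wait for pending GPU work; a no-op on the CPU."""
    if device == "cuda":
        torch.cuda.synchronize()


def benchmark(fn, warmup=3, iters=10):
    """Return per-iteration times in ms, excluding the warmup runs."""
    for _ in range(warmup):  # context init and clock ramp-up happen here
        fn()
    times = []
    for _ in range(iters):
        sync()
        t0 = time.perf_counter()
        fn()
        sync()
        times.append(1000 * (time.perf_counter() - t0))
    return times


# Stand-in model and input, much smaller than the ones in the thread.
model = nn.Conv2d(3, 8, kernel_size=3, padding=2, dilation=2).to(device).eval()
x = torch.rand(1, 3, 64, 64, device=device)


def forward():
    with torch.no_grad():
        return model(x)


times_ms = benchmark(forward)
median = sorted(times_ms)[len(times_ms) // 2]
print(f"min {min(times_ms):.2f} ms, median {median:.2f} ms")
```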

I added more iterations, but the behavior is still the same.

Are you constantly seeing the ~40x increase in each odd iteration?
If so, could you update to 1.9.0 and rerun the code, as I’m having trouble reproducing the issue?

Are you constantly seeing the ~40x increase in each odd iteration?

Yes.

If so, could you update to 1.9.0 and rerun the code, as I’m having trouble reproducing the issue?

  • 1.9.0+cu111: both versions are fast (same behavior you observed)
  • 1.8.1+cu111: here I have the reported performance gap

So it seems the issue only shows up when the venv contains the PyTorch LTS version, which I installed via:
pip3 install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
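Since the behavior differs between the two installs, it helps to print the exact versions from inside each venv when comparing results. A short sketch (all values are None or "none" on a CPU-only build):

```python
import torch

# Report the versions that are relevant when comparing two environments.
print("torch:", torch.__version__)
print("CUDA (build):", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none"
print("GPU:", gpu)
```

Running python -m torch.utils.collect_env gives a more complete report of the same kind.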