Efficiently converting np.uint8 to torch.float32

I was piping some data from ffmpeg into torch and noticed that I was doing something very naive. So I profiled, with a single process, the conversion from np.uint8 to torch.float32.
The difference on the CPU can be almost an order of magnitude; the GPU tightens everything.

If anyone is interested.


Thanks for sharing the code!

Be a bit careful about the CUDA numbers, since you didn’t synchronize the calls.
CUDA calls will be executed asynchronously, which means the CPU can continue executing the code while the CUDA kernels are busy until a synchronization point is reached.
These points can be added manually using torch.cuda.synchronize() or are reached automatically, e.g. when a result of the CUDA operation is needed.
In your profiling script you might in fact just be timing the kernel launch, when the .cuda() call is the first op on the tensor.

To properly time CUDA calls, you would have to synchronize before starting and stopping the timer:

import time

torch.cuda.synchronize()  # make sure pending kernels are done before starting the timer
t0 = time.time()
...  # CUDA operations to be timed
torch.cuda.synchronize()  # wait for the timed kernels to finish
t1 = time.time()
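
As a concrete (hypothetical) sketch, timing the uint8 to float32 conversion itself with proper synchronization could look like this; the array shape and iteration counts are placeholders, not the values from the gist:

import time

import numpy as np
import torch

# dummy uint8 image batch (placeholder shape, not taken from the gist)
arr = np.random.randint(0, 256, size=(64, 224, 224, 3), dtype=np.uint8)
nb_iters = 100

# warmup, so kernel launch / allocation overhead is not part of the measurement
for _ in range(10):
    x = torch.from_numpy(arr).cuda().float().div_(255.)

torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    x = torch.from_numpy(arr).cuda().float().div_(255.)
torch.cuda.synchronize()
t1 = time.time()

print((t1 - t0) / nb_iters)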

Thank you ptrblck! You are right, I didn't sync the CUDA calls. I'll update the gist when I get to it. I also wonder whether this profiling is at all significant when using multi-process, multi-GPU setups. I need to update my hardware anyway, which, as you can see, is from the Paleozoic.

I wrote this gist mainly because I had ndarray(uint8)/255 -> torch (float32, cuda) peppered all over my code, which is a waste of resources and time. I figured someone else might be making the same mistake.
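
Concretely, the wasteful pattern vs. the version I switched to looks roughly like this (a minimal sketch; the exact variants and their timings are what the gist measures):

import numpy as np
import torch

arr = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)

# naive: dividing in numpy promotes to float64 on the CPU before the transfer
naive = torch.from_numpy(arr / 255).float().cuda()

# leaner: move the small uint8 tensor first, then convert and scale on the GPU
lean = torch.from_numpy(arr).cuda().float().div_(255.)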


Sure, and thanks again for posting!
I just skimmed through your code and will dig into it a bit later, as I'm interested in the results :wink:

Great - if you look at it, there are a few things I didn't post: float64 and float16 torch data, but those numbers vary wildly. I don't have a good setup to deal with double or half floats, so the comparison isn't valid.

I also found some funny numbers to do with contiguity, which I didn't post either, because first it worked one way and then another. But it went something like this:

tensor.permute(2, 1, 0).contiguous()
# vs.
w, h, c = tensor.shape
out = torch.zeros([1, c, h, w])
out[0] = tensor.permute(2, 1, 0)
# both return a contiguous tensor

In some instances I got the latter to be 2x the speed of the ‘proper’ way, although maybe this is meaningless, because the test wasn't consistent.

At any rate, if one is serious about speed, one should probably be working in C++. Which made me think: maybe the way to get rid of all this uncertainty is to load from the buffer into a contiguous torch tensor of the right shape, dtype and device, directly in C++… But that's for another day.
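
For what it's worth, the Python side of that idea (wrapping the raw bytes from the pipe directly in a uint8 tensor) could be sketched like this; the frame geometry and the zero-filled stand-in buffer are hypothetical, not from my actual pipeline:

import numpy as np
import torch

# hypothetical frame geometry; replace with the real ffmpeg output size
h, w, c = 720, 1280, 3

# stand-in for one raw rgb24 frame read from the ffmpeg pipe
raw = bytes(h * w * c)

# wrap the bytes without an element-wise copy, reshape, then convert on the GPU
frame_np = np.frombuffer(raw, dtype=np.uint8).reshape(h, w, c)
frame = torch.from_numpy(frame_np.copy())  # .copy() yields a writable, contiguous array
frame = frame.cuda().float().div_(255.)    # scale to [0, 1] on the device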

I don’t think these operations see any significant speedup in C++, if the tensors are “large enough” (i.e. not tiny).
Usually you would just see the Python overhead, which is negligible if the operations have some actual workload.

Both methods should yield similar timings, and I would assume the second approach to be slower.
At least, that’s the result from my profiling:

import time

import torch

x = torch.randn(64, 128, 64, 64, device='cuda')
nb_iters = 1000

# warmup
for _ in range(100):
    y = x.permute(0, 3, 2, 1).contiguous()

torch.cuda.synchronize()
t0 = time.time()

for _ in range(nb_iters):
    y = x.permute(0, 3, 2, 1).contiguous()

torch.cuda.synchronize()
t1 = time.time()

print((t1 - t0) / nb_iters)


# warmup
for _ in range(100):
    out =  torch.zeros(64, 64, 64, 128, device='cuda')
    out[:] = x.permute(0, 3, 2, 1)

torch.cuda.synchronize()
t0 = time.time()

for _ in range(nb_iters):
    out =  torch.zeros(64, 64, 64, 128, device='cuda')
    out[:] = x.permute(0, 3, 2, 1)

torch.cuda.synchronize()
t1 = time.time()

print((t1 - t0) / nb_iters)

Permute + contiguous: 0.00072478 s
Permute + copy: 0.00093216 s
