Checking if a tensor is all 0's with tensor.sum().data[0] == 0 is extremely slow

I recently profiled my model code, and one surprising thing that came up was how slow it is to check whether a tensor is all 0’s. More specifically,

my_tensor.sum().data[0] == 0

takes upwards of 0.01 seconds per execution, which can add up. I’m guessing this is because ‘my_tensor’ is a Variable on the GPU and must be transferred to the CPU before doing the boolean check. My question is whether there is a faster way to check whether ‘my_tensor’ is all 0’s (i.e., to turn the check into a Python boolean True/False). Assume ‘my_tensor’ needs to stay a GPU Variable.

I’m thinking of doing,

len(torch.nonzero(my_tensor)) < 1

Theoretically this should work but it seems a bit hacky.
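For reference, here is a minimal sketch of both variants (using the modern .item() accessor instead of the older .data[0]; the function names are just illustrative). On a CUDA tensor, both ultimately have to bring a result back to the host, so both force a device-to-CPU synchronization:

```python
import torch

def is_all_zero_sum(t):
    # Sum-based check: reads one scalar back on the CPU,
    # which forces a GPU->CPU synchronization for CUDA tensors.
    return t.sum().item() == 0

def is_all_zero_nonzero(t):
    # nonzero-based check: the number of nonzero entries still has
    # to be known on the CPU, so this synchronizes just the same.
    return len(torch.nonzero(t)) == 0

t = torch.zeros(1000)
if torch.cuda.is_available():  # keep the sketch runnable on CPU-only machines
    t = t.cuda()

print(is_all_zero_sum(t), is_all_zero_nonzero(t))
```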

Hi,

Did you try calling torch.cuda.synchronize() just before this line to check whether it isn’t the sync that takes all the time? Remember that the CUDA API is asynchronous and only waits for the computation to be done when you actually get the results back to the CPU.
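The suggested experiment could be sketched like this (an assumed setup with a made-up workload, not the poster’s actual code): synchronizing first drains any pending kernels, so the timed region measures only the check itself rather than all the queued GPU work.

```python
import time
import torch

x = torch.randn(1000, 1000)
if torch.cuda.is_available():
    x = x.cuda()

y = x.mm(x)  # on a GPU this kernel is launched asynchronously

if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for pending kernels before timing

start = time.time()
is_zero = y.sum().item() == 0  # only the transfer/sync is left to measure
elapsed = time.time() - start
print(is_zero, elapsed)
```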

You were right, synchronization was the problem. For this reason my proposed solution didn’t work. I guess there’s really no way of getting around this overhead since my code checks the condition to decide whether to break out of a computation for-loop. Maybe I’m wrong?

for t in range(10):
    self.RNN(stuff)
    if my_tensor.float().sum().data[0] == 0:
        break

By the way, where can I read more about how the async stuff works in PyTorch? A quick Google search didn’t bring up anything useful.

Unfortunately, if the result of a computation is needed for control flow there is not much you can do :confused:
You can read about cuda semantics in the doc.

1 Like

Hi, I profiled my scripts and found that item() occupies almost half the time!

I read the CUDA semantics doc and it notes that “PyTorch automatically performs necessary synchronization when copying data between CPU and GPU or between two GPUs.”, so the item() function will copy data from GPU to CPU when the Tensor is on the GPU? Is there any other situation that will trigger synchronization? Thank you!

Hi,

I am not a huge expert, but the most common ones are: moving a tensor to the CPU (possibly with a type change); trying to print it or get a value as a Python number; and forced deletion with torch.cuda.empty_cache().
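The cases listed above can be sketched as follows (each of these lines implicitly waits for pending GPU work when the tensor actually lives on the GPU; on a CPU-only machine the guarded block is simply skipped):

```python
import torch

t = torch.ones(10)
if torch.cuda.is_available():
    t = t.cuda()
    # Each of these implicitly synchronizes with the GPU:
    t.cpu()                   # copy the tensor back to host memory
    t[0].item()               # read one element as a Python number
    print(t)                  # printing needs the actual values on the CPU
    torch.cuda.empty_cache()  # release cached GPU memory
```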

1 Like