Conditional indexing is extremely slow?

I’m implementing something like ReLU:

import torch.nn as nn

class My_ReLU(nn.Module):

    def __init__(self):
        super(My_ReLU, self).__init__()

    def forward(self, x):

        x[x < 0] *= 0

        return x

I found my code is extremely slow compared to PyTorch's official ReLU, mainly because of x[x < 0] *= 0.

So is there a faster way to do conditional indexing?

Why is the official ReLU so fast?

The thing is that the official ReLU is implemented as a single kernel that does the whole ReLU in one pass.
What you do here is 3 different operations: comparison with 0, advanced indexing, and element-wise multiplication.
If you run this on the GPU, this will launch 3 kernels instead of 1 and so is expected to be slower.

This is exactly the reason why we have specialized kernels for all the common operations!
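
Concretely, the single indexing line decomposes into roughly the following steps (a sketch only; the exact dispatch is an implementation detail and the variable names are illustrative):

mask = x < 0            # step 1: element-wise comparison builds a boolean mask
gathered = x[mask]      # step 2: advanced indexing gathers the negative elements
x[mask] = gathered * 0  # step 3: multiply, then scatter the result back into x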

I wrote the following test script:

import torch
import torch.nn as nn
import time

class My_ReLU(nn.Module):

    def __init__(self):
        super(My_ReLU, self).__init__()

    def forward(self, x):

        x[x < 0] *= 0

        return x


official_relu = nn.ReLU().cuda()
my_relu = My_ReLU().cuda()

data = torch.rand(1000, 1000).cuda()

start = time.time()
for _ in range(100):
    official_relu(data)
print(time.time() - start)


start = time.time()
for _ in range(100):
    my_relu(data)
print(time.time() - start)

and the output is

0.0022401809692382812
0.04940199851989746

I don’t think this huge gap is caused by just 3 kernels vs 1 kernel.

Hi,

All CUDA operations are asynchronous, so you should add torch.cuda.synchronize() to make sure that you measure the actual execution time and not just how long it took to queue the kernels.
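
For example, the first timing loop could be bracketed like this (a minimal sketch; the same change applies to the my_relu loop):

torch.cuda.synchronize()      # make sure previously queued work has finished
start = time.time()
for _ in range(100):
    official_relu(data)
torch.cuda.synchronize()      # wait for all 100 launches to actually complete
print(time.time() - start)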

Hi albanD,

So is writing a CUDA extension the only way to make my code faster?

Avoiding the multiplication will help: x[x < 0] = 0 will be faster.
Avoiding the indexing altogether will be even better, I think: x *= (x >= 0).type_as(x).
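
As standalone functions, the two suggestions would look roughly like this (a sketch with hypothetical names, assuming x is a float tensor):

def relu_assign(x):            # variant 1: masked assignment
    x[x < 0] = 0               # scatter zeros into the negative positions, no multiplication needed
    return x

def relu_mask_multiply(x):     # variant 2: avoid advanced indexing entirely
    x *= (x >= 0).type_as(x)   # cast the boolean mask to x's dtype, then one element-wise multiply
    return x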