Custom autograd Function slows things down significantly (PyTorch 0.4.0)

A custom torch.autograd.Function seems to slow things down significantly; in particular, the backward pass takes far longer. As a simple check, I used the custom ReLU function from the official tutorial (https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html) in place of the built-in relu and kept everything else the same. However, the time to run 100 iterations with batch size 128 (ResNet56 on CIFAR-10) is 0.251 for the custom version vs 0.002 for the built-in one, roughly 100x slower. I am on PyTorch 0.4.0. Can you please help me?

[Edit: on PyTorch 0.3.1 both cases run at about the same speed (0.077 vs 0.065). But I want to use 0.4.0 to avoid the “TypeError: ‘module’ object is not callable” I get when adding self.scale = nn.Parameter(torch.tensor([10.0]), requires_grad=True) to an nn.Module.]
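(For context: on 0.3.1, torch.tensor appears to resolve to the torch/tensor.py module rather than the factory function that 0.4.0 introduced, which is presumably what produces that TypeError. A minimal sketch of the construct that works on 0.4.0, with a hypothetical module name:)

import torch
import torch.nn as nn

class ScaledModule(nn.Module):  # hypothetical module, just to show the construct
    def __init__(self):
        super(ScaledModule, self).__init__()
        # Works on 0.4.0, where torch.tensor() is a factory function; on 0.3.1
        # calling torch.tensor raises "TypeError: 'module' object is not callable".
        self.scale = nn.Parameter(torch.tensor([10.0]), requires_grad=True)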

class myReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
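(As an aside: a quick way to sanity-check that a custom Function's backward is numerically correct, separate from the speed issue, is torch.autograd.gradcheck. A minimal sketch, assuming double-precision inputs:)

import torch

# Verify the custom backward against numerical gradients. gradcheck wants
# double precision; note ReLU is not differentiable at exactly 0, so inputs
# very close to zero can occasionally make the check flaky.
x = torch.randn(8, 8, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(myReLU.apply, (x,)))  # prints True if gradients match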

class PreActBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        #out = self.relu(out)
        out = myReLU.apply(out)

        if self.downsample is not None:
            residual = self.downsample(out)

        out = self.conv1(out)

        out = self.bn2(out)
        #out = self.relu(out)
        out = myReLU.apply(out)
        out = self.conv2(out)

        out = self.bn3(out)
        #out = self.relu(out)
        out = myReLU.apply(out)
        out = self.conv3(out)

        out += residual

        return out

Your benchmarking is likely wrong.

Make sure you benchmark GPU code like this:

torch.cuda.synchronize()  # finish all previous CUDA calls
tm = time.time()
out = model(input)
torch.cuda.synchronize()  # wait for the forward pass to finish before reading the clock
print(time.time() - tm)
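For steadier numbers it also helps to warm up first and average over many iterations. A minimal sketch along those lines (time_forward is a hypothetical helper; the warm-up count of 10 is an arbitrary choice):

import time
import torch

def time_forward(model, input, iters=100):
    # Warm-up runs so one-time costs (cuDNN autotuning, allocator growth)
    # don't land inside the timed region.
    for _ in range(10):
        model(input)
    torch.cuda.synchronize()
    tm = time.time()
    for _ in range(iters):
        model(input)
    torch.cuda.synchronize()
    return (time.time() - tm) / iters  # average seconds per forward pass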

(PyTorch 0.4.0) Yes, applying your benchmarking to the forward pass alone (as in your code) gives similar times for relu and the custom relu (0.014 vs 0.014).
However, when I time forward and backward together, the custom relu is slower (0.051 vs 0.218); the backward pass seems to take very long. On PyTorch 0.3.1.post2, the forward-plus-backward times for relu and the custom relu are about the same (0.058 vs 0.056) with the benchmarking below… Can you please help me with this? Thank you.

    torch.cuda.synchronize()
    tm = time.time()
    output = model(input_var)
    loss = criterion(output, target_var)
    loss.backward()
    torch.cuda.synchronize()
    print(time.time() - tm)

I also ran into this problem. Did you solve it?

No, I downgraded to 0.3.1, which doesn’t have this problem :frowning:


To anyone who has the same problem:

grad_input[input < 0] = 0

is a slow operation (even though it is used in the official tutorial).
Changing it to:

grad_input = grad_input * (input > 0).float()

works for me.
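For reference, here is the backward method from the snippet above with that one-line change applied. (Note the mask also zeroes the gradient at exactly input == 0, where the original indexing left it untouched; that matches the usual ReLU convention anyway.)

@staticmethod
def backward(ctx, grad_output):
    input, = ctx.saved_tensors
    # Multiplying by a float mask avoids the slow masked-assignment path.
    return grad_output * (input > 0).float()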