A custom torch.autograd.Function seems to slow things down considerably; the backward pass in particular takes far too long. As a simple verification, I replaced nn.ReLU with the custom ReLU function from the tutorial (https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html) and kept everything else the same. However, the time spent running 100 iterations at batch size 128 (ResNet56 on CIFAR-10) is 0.251 vs. 0.002, i.e. about 100x slower. I am using PyTorch 0.4.0. Can you please help me?
[Edit: on PyTorch 0.3.1 the two cases run at about the same speed (0.077 vs. 0.065). But I want to use 0.4.0 to avoid the “TypeError: ‘module’ object is not callable” that 0.3.1 raises when I add self.scale = nn.Parameter(torch.tensor([10.0]), requires_grad=True) to an nn.Module.]
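For reference, the parameter in question is defined like this (a minimal sketch; ScaledModule is a placeholder name, not my actual model):

import torch
import torch.nn as nn

class ScaledModule(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super(ScaledModule, self).__init__()
        # Works on 0.4.0, where torch.tensor() is a factory function;
        # on 0.3.1 this fails, presumably because torch.tensor is a submodule there.
        self.scale = nn.Parameter(torch.tensor([10.0]), requires_grad=True)

Here is the custom ReLU and the block that uses it: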
import torch
import torch.nn as nn

class myReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)   # stash the input for the backward pass
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0      # zero the gradient where the input was negative
        return grad_input
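To rule out a correctness problem, here is a quick sanity check (a minimal sketch, assuming the myReLU class above is in scope) confirming the custom op matches the built-in ReLU in both passes:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, requires_grad=True)

# Forward: clamp(min=0) and the built-in ReLU should agree exactly.
y_custom = myReLU.apply(x)
y_builtin = F.relu(x)
print((y_custom - y_builtin).abs().max())  # expect 0

# Backward: compare the gradients from the two separate graphs.
y_custom.sum().backward()
g_custom = x.grad.clone()
x.grad.zero_()
y_builtin.sum().backward()
print((g_custom - x.grad).abs().max())     # expect 0 (up to the input == 0 edge case)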
class PreActBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.bn1(x)
        # out = self.relu(out)
        out = myReLU.apply(out)   # custom ReLU in place of the built-in
        if self.downsample is not None:
            residual = self.downsample(out)
        out = self.conv1(out)
        out = self.bn2(out)
        # out = self.relu(out)
        out = myReLU.apply(out)
        out = self.conv2(out)
        out = self.bn3(out)
        # out = self.relu(out)
        out = myReLU.apply(out)
        out = self.conv3(out)
        out += residual
        return out
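This is roughly how I time it (a sketch, not my exact harness; it assumes the classes above are in scope, and on GPU it synchronizes before reading the clock, since CUDA kernels run asynchronously):

import time
import torch

block = PreActBottleneck(inplanes=16, planes=4)   # conv3 outputs 4 * 4 = 16 channels, matching the residual
x = torch.randn(128, 16, 32, 32)                  # batch of 128, CIFAR-sized feature maps
if torch.cuda.is_available():
    block, x = block.cuda(), x.cuda()

if torch.cuda.is_available():
    torch.cuda.synchronize()                      # finish pending GPU work before starting the clock
start = time.time()
for _ in range(100):                              # 100 forward/backward iterations
    block.zero_grad()
    out = block(x)
    out.sum().backward()
if torch.cuda.is_available():
    torch.cuda.synchronize()                      # wait for queued kernels before stopping the clock
print(time.time() - start)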