Backward pass 6x slower on CPU

Hello, I am building a model-parallel CNN on a CPU/GPU cluster. On both CPU and GPU, the backward pass is about 6x slower than the forward pass. On CPU in particular, the autograd profiler lists convolution operations as the main bottleneck. The forward functions of the modules are somewhat complex and use a lot of if statements and indexing similar to the following:

import torch

a = torch.ones(2, 3)
b = a[:, [0, 1, 2]]  # select columns 0, 1, 2 in order (no permutation)
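Whether this kind of indexing is free depends on how it is written. A minimal sketch (shapes chosen only for illustration): indexing with a Python list is advanced indexing, which in PyTorch always allocates a new tensor, even when the indices are in order, whereas a slice returns a view that shares storage. As far as I know, the autograd backward of list indexing also has to scatter the gradient back into a zero tensor of the original shape, while a slice's backward is cheaper.

```python
import torch

a = torch.ones(2, 4)

b_copy = a[:, [0, 1, 2]]  # advanced (list) indexing: materializes a new tensor
b_view = a[:, 0:3]        # basic slicing: a strided view into a's storage

print(b_copy.equal(b_view))               # True  (same values)
print(b_view.data_ptr() == a.data_ptr())  # True  (view starts at a's storage)
print(b_copy.data_ptr() == a.data_ptr())  # False (copy has its own storage)

a[0, 0] = 5.0  # writes through to the view, but not to the copy
print(b_view[0, 0].item(), b_copy[0, 0].item())  # 5.0 1.0
```

In hot code paths, slices (or `narrow`) avoid the extra allocation when the selected indices are contiguous.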

Since there is not even a permutation in the second dimension, I assumed this kind of indexing has no overhead. From what I have read, the backward pass is usually about 2x slower than the forward pass, which is why I am raising this question. Because of the model-parallel design, the weights of a single layer are divided across processes/nodes, so one convolution layer is represented as a combination of several convolutions spanning multiple nodes. Maybe this setup affects performance. Below is the autograd profile for the forward and backward pass on CPU. Any help on this is highly appreciated. Thanks
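To make the setup concrete, here is a minimal, single-process sketch of how one convolution can be expressed as several smaller ones. It assumes the weight is sharded either along output channels (results concatenated) or along input channels (partial results summed); the actual sharding scheme and the C_Connect/S_Connect communication are not shown in the post, so this only illustrates the math, not the distributed code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# float64 only so the comparison below can use tight tolerances.
x = torch.randn(4, 16, 28, 28, dtype=torch.float64)
w = torch.randn(32, 16, 3, 3, dtype=torch.float64)

full = F.conv2d(x, w, padding=1)

# (a) Shard along output channels: each shard yields some output maps; concatenate.
out_shards = [F.conv2d(x, w_s, padding=1) for w_s in w.chunk(2, dim=0)]
split_out = torch.cat(out_shards, dim=1)

# (b) Shard along input channels: each shard sees its slice of the input; sum partials.
partials = [F.conv2d(x_s, w_s, padding=1)
            for x_s, w_s in zip(x.chunk(2, dim=1), w.chunk(2, dim=1))]
split_in = torch.stack(partials).sum(dim=0)

print(torch.allclose(full, split_out), torch.allclose(full, split_in))  # True True
```

Either way, each shard is its own convolution call in the forward and backward pass, so sharding by itself should multiply the number of convolution calls by roughly the number of shards, not by thousands.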

Profile Data

Using PyTorch v1.6, MKL-DNN v1.2.0. Here ‘Connect’ ops are the functions used for communication between nodes.

Forward Pass

--------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                        Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  
--------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
mkldnn_convolution          36.62%           120.430ms        73.66%           242.227ms        5.767ms          42               
C_Connect                   28.26%           92.930ms         31.67%           104.140ms        4.734ms          22               
S_Connect                   18.63%           61.264ms         20.19%           66.377ms         8.297ms          8                
native_batch_norm           4.63%            15.231ms         10.89%           35.819ms         895.465us        40               
_cat                        3.58%            11.760ms         4.21%            13.851ms         364.492us        38               
add                         1.08%            3.536ms          2.51%            8.261ms          56.585us         146              
empty                       0.78%            2.562ms          0.79%            2.603ms          7.912us          329              
mul                         0.57%            1.861ms          1.16%            3.822ms          238.906us        16               
sub                         0.55%            1.796ms          1.12%            3.692ms          230.766us        16               
op_Conv2D                   0.47%            1.532ms          37.70%           123.974ms        5.904ms          21               
is_leaf                     0.41%            1.341ms          0.50%            1.648ms          0.659us          2500             
leaky_relu                  0.39%            1.277ms          0.84%            2.759ms          153.299us        18   

Backward Pass

-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
Name                                 Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls  
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  
slow_conv_dilated2d                  20.68%           1.287s           69.93%           4.352s           37.949us         114688           
slow_conv_transpose2d                13.49%           839.662ms        27.88%           1.735s           43.386ms         40               
copy_                                11.12%           691.946ms        11.37%           707.678ms        2.842ms          249              
mkldnn_convolution                   10.04%           624.944ms        20.11%           1.252s           41.719ms         30               
size                                 6.53%            406.650ms        6.53%            406.650ms        0.243us          1674695          
op_Conv2DBackward                    5.99%            372.616ms        95.47%           5.942s           282.958ms        21               
slice                                4.58%            285.351ms        8.29%            515.909ms        2.246us          229714           
_cat                                 4.14%            257.976ms        5.55%            345.569ms        12.342ms         28               
_convolution                         2.87%            178.902ms        77.67%           4.834s           117.910ms        41               
empty                                2.75%            171.139ms        2.78%            172.734ms        1.499us          115224           
fill_                                2.39%            149.045ms        2.39%            149.047ms        2.592us          57500            
as_strided                           2.36%            147.065ms        2.36%            147.065ms        0.631us          233126           
select                               2.33%            144.808ms        4.15%            258.508ms        2.193us          117878           
narrow                               1.96%            121.740ms        8.09%            503.789ms        4.384us          114916           
C_ConnectBackward                    1.95%            121.317ms        2.45%            152.702ms        6.941ms          22               
contiguous                           1.40%            86.923ms         1.56%            97.384ms         0.424us          229539

Do you have an idea why the number of calls to the convolution in the backward pass is roughly 2,700x higher than in the forward pass (114,688 calls to slow_conv_dilated2d vs. 42 calls to mkldnn_convolution)?
Could the splitting of the convolution you mentioned cause this increase in calls?

Thank you @ptrblck for pointing that out. It actually happens when I use torch.nn.grad.conv2d_input and torch.nn.grad.conv2d_weight in the backward pass of a custom conv op. With stride=1, the number of calls to mkldnn_convolution is normal, but the backward pass is about 100x slower. With stride=2, both the number of calls to slow_conv_dilated2d and the backward time increase heavily. Hopefully you can reproduce this with the following code. Maybe I am doing something wrong in the custom convolution operation. Thanks.

import torch
import torch.nn as nn
import torch.nn.functional as F

class op_Conv2D(torch.autograd.Function):
  
  @staticmethod
  def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    ctx.save_for_backward(input, weight, bias)
    ctx.stride = stride
    ctx.padding = padding
    ctx.groups = groups
    ctx.dilation = dilation
    out = F.conv2d(input, weight, bias, stride, padding, dilation, groups)
    return out
    
  @staticmethod
  def backward(ctx, grad_output):
    input, weight, bias = ctx.saved_tensors
    grad_input = grad_weight= grad_bias = None

    if ctx.needs_input_grad[0]:
      # gradient w.r.t. the input (computed via a transposed convolution)
      grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)

    if ctx.needs_input_grad[1]:
      # gradient w.r.t. the weight
      grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)

    if bias is not None and ctx.needs_input_grad[2]:
      grad_bias = grad_output.sum((0,2,3))

    # grad_bias stays None without a bias; the trailing Nones are for stride, padding, dilation, groups
    return grad_input, grad_weight, grad_bias, None, None, None, None

w = nn.Parameter(torch.randn(32, 32, 3, 3, device=torch.device("cpu")))
b = nn.Parameter(torch.zeros(32, device=torch.device("cpu"))) 

data = torch.rand(32, 32, 3, 3)

with torch.autograd.profiler.profile(use_cuda=False) as prof:
  x_act = op_Conv2D.apply(data,w, b, 2, 0, 1, 1)
print("forward pass F.conv2d --- \n", prof.key_averages().table(sort_by="self_cpu_time_total"))

with torch.autograd.profiler.profile(use_cuda=False) as prof:
  x_act.sum().backward()
print("backward pass F.conv2d --- \n", prof.key_averages().table(sort_by="self_cpu_time_total"))


conv_block = nn.Conv2d(32, 32, 3, stride=2).cpu()

with torch.autograd.profiler.profile(use_cuda=False) as prof:
  x_act = conv_block(data)
print("forward pass nn.Conv2d --- \n", prof.key_averages().table(sort_by="self_cpu_time_total"))

with torch.autograd.profiler.profile(use_cuda=False) as prof:
  x_act.sum().backward()
print("backward pass nn.Conv2d --- \n", prof.key_averages().table(sort_by="self_cpu_time_total"))

I can’t see anything obviously wrong and am unfortunately not familiar with the mkldnn implementation.

There are several related issues on GitHub, and I think this is not solved yet. I also recently opened issue #37308 on this, so I will post the reproduction there. Thanks again for pointing out the culprit.
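One way to keep a custom autograd.Function while avoiding torch.nn.grad.conv2d_input/conv2d_weight is to let autograd compute the gradients with the built-in convolution backward: inside backward, re-run F.conv2d on detached copies of the saved tensors under torch.enable_grad() and call torch.autograd.grad. This is a swapped-in technique, not what the thread uses, and it costs one extra forward convolution per backward call. A possible reason for the stride=2 behavior (hedged, from reading the PyTorch 1.x source): conv2d_weight appears to compute the weight gradient as a grouped convolution with groups = in_channels * batch_size and the forward stride passed in as the dilation, which would route to slow_conv_dilated2d on CPU. The class name RecomputeConv2d is hypothetical.

```python
import torch
import torch.nn.functional as F


class RecomputeConv2d(torch.autograd.Function):
    """Hypothetical custom conv op: forward is F.conv2d, backward reuses
    autograd's native convolution backward instead of torch.nn.grad."""

    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        ctx.save_for_backward(input, weight, bias)
        ctx.conv_args = (stride, padding, dilation, groups)
        return F.conv2d(input, weight, bias, stride, padding, dilation, groups)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        stride, padding, dilation, groups = ctx.conv_args
        need_x, need_w, need_b = ctx.needs_input_grad[:3]

        # Rebuild a small graph on detached leaves; grad mode is off inside
        # backward by default, so it has to be re-enabled explicitly.
        with torch.enable_grad():
            x = input.detach().requires_grad_(need_x)
            w = weight.detach().requires_grad_(need_w)
            b = bias.detach().requires_grad_(need_b) if bias is not None else None
            out = F.conv2d(x, w, b, stride, padding, dilation, groups)

        # Only ask autograd for the gradients that are actually needed.
        leaves = [t for t in (x, w, b) if t is not None and t.requires_grad]
        grads = iter(torch.autograd.grad(out, leaves, grad_output))

        grad_input = next(grads) if need_x else None
        grad_weight = next(grads) if need_w else None
        grad_bias = next(grads) if (bias is not None and need_b) else None
        # One gradient per forward argument; stride/padding/dilation/groups get None.
        return grad_input, grad_weight, grad_bias, None, None, None, None


# Usage: strided convolution, gradients for input, weight and bias.
x = torch.randn(8, 32, 16, 16, requires_grad=True)
w = torch.randn(32, 32, 3, 3, requires_grad=True)
b = torch.zeros(32, requires_grad=True)
y = RecomputeConv2d.apply(x, w, b, 2, 0, 1, 1)
y.sum().backward()
```

Whether this is faster than torch.nn.grad in a given PyTorch/MKL-DNN version would need to be checked with the same profiler setup as above.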

To avoid nn.grad.conv2d_input and nn.grad.conv2d_weight, whose performance in the backward pass is problematic, I tried using nn.Conv2d with a few modifications. My use case is to mask the weight tensor and use only a subset of its input channels before applying the convolution. For example:

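import torch
import torch.nn as nn

w_true = nn.Parameter(torch.randn(32, 16, 3, 3))
w_modfd = torch.rand_like(w_true) * w_true  # masking
w_modfd = w_modfd[:, :8]  # use only a subset of the input channels

dummy_conv = nn.Conv2d(8, 32, 3, bias=False)  # shapes match w_modfd

for _ in range(3):  # stands in for the training loop
  x = torch.randn(128, 8, 28, 28)
  dummy_conv.weight.data = w_modfd  # note: .data bypasses autograd

  out = dummy_conv(x)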

Since the performance of the nn.Conv2d module is as expected, this runs smoothly, and as far as I can see accuracy is not affected either. But are there any obvious pitfalls in this method with regard to memory usage, correctness, etc.? One thing I noticed is that w_true.grad is not calculated; if we need the gradient of w_true, we have to access it through dummy_conv.weight.grad.

The usage of the .data attribute is not recommended, as it might have side effects, such as wrong gradient calculation.
If you want to set w_modfd as the filter kernels for dummy_conv, you could assign it as a new nn.Parameter outside of the loop.
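If the goal is for gradients to reach w_true itself, another option (a swapped-in alternative to both the .data assignment and a separate nn.Parameter) is to skip the module and call F.conv2d with the masked weight recomputed from w_true inside the loop. A minimal sketch, assuming a fixed random mask as in the post; `mask`, `optimizer` and the squared-output loss are hypothetical placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
w_true = nn.Parameter(torch.randn(32, 16, 3, 3))
mask = torch.rand_like(w_true)  # fixed random mask, created once (not a Parameter)
optimizer = torch.optim.SGD([w_true], lr=0.01)

for _ in range(3):  # stands in for the real training loop
    x = torch.randn(8, 8, 28, 28)
    # Recompute the masked, channel-sliced weight every iteration so that
    # each forward builds a fresh graph back to w_true.
    w_modfd = (mask * w_true)[:, :8]
    out = F.conv2d(x, w_modfd)
    loss = out.pow(2).mean()  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Input channels 8..15 are never used, so their gradient is exactly zero.
print(w_true.grad[:, 8:].abs().sum().item())  # 0.0
```

The extra cost is one small masked weight tensor per iteration, and w_true.grad is populated directly, so no gradient has to be copied back from a module.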