Is it possible to compute CatBackward in parallel?

I use torch.cat to concatenate the outputs of multiple output layers of my network, which results in grad_fn=<CatBackward>. From torch.autograd.profiler.profile() I observe that CatBackward seems to be computed sequentially.

Is it possible to make CatBackward run in parallel? Thanks!

Name                                        CPU time        CUDA time            Calls        CPU total       CUDA total
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------
torch::autograd::GraphRoot                   5.200us          3.072us                1          5.200us          3.072us
MseLossBackward                             62.497us         59.360us                1         62.497us         59.360us
mse_loss_backward                           47.498us         48.096us                1         47.498us         48.096us
CatBackward                                381.981us        404.480us                1        381.981us        404.480us
narrow                                      12.000us         12.256us                1         12.000us         12.256us
slice                                        5.500us          5.088us                1          5.500us          5.088us
... (narrow + slice pair repeated 32 times) ...
narrow                                       6.999us          8.192us                1          6.999us          8.192us
slice                                        2.699us          2.080us                1          2.699us          2.080us
AddmmBackward                               88.696us         84.992us                1         88.696us         84.992us
unsigned short                               6.100us          4.096us                1          6.100us          4.096us
mm                                          24.798us         27.648us                1         24.798us         27.648us
unsigned short                               5.300us          5.152us                1          5.300us          5.152us
mm                                          25.498us         29.728us                1         25.498us         29.728us
unsigned short                               5.000us          3.072us                1          5.000us          3.072us
sum                                         21.799us         22.528us                1         21.799us         22.528us
view                                         8.299us          8.160us                1          8.299us          8.160us
torch::autograd::AccumulateGrad             18.099us         18.432us                1         18.099us         18.432us
TBackward                                   10.999us         11.264us                1         10.999us         11.264us
unsigned short                               5.700us          5.120us                1          5.700us          5.120us
torch::autograd::AccumulateGrad             13.699us         14.304us                1         13.699us         14.304us
ThresholdBackward0                          25.999us         26.624us                1         25.999us         26.624us
threshold_backward                          18.599us         19.456us                1         18.599us         19.456us
AddmmBackward                               83.496us         84.960us                1         83.496us         84.960us
unsigned short                               4.700us          4.096us                1          4.700us          4.096us
mm                                          26.399us         31.744us                1         26.399us         31.744us
unsigned short                               5.100us          3.072us                1          5.100us          3.072us
mm                                          21.899us         25.568us                1         21.899us         25.568us
unsigned short                               4.300us          2.048us                1          4.300us          2.048us
sum                                         16.800us         18.432us                1         16.800us         18.432us
view                                         6.900us          5.152us                1          6.900us          5.152us
torch::autograd::AccumulateGrad             14.999us         15.360us                1         14.999us         15.360us
TBackward                                    9.899us         10.272us                1          9.899us         10.272us
unsigned short                               4.799us          4.128us                1          4.799us          4.128us
torch::autograd::AccumulateGrad             13.199us         14.336us                1         13.199us         14.336us
ThresholdBackward0                          22.299us         22.560us                1         22.299us         22.560us
threshold_backward                          15.499us         16.384us                1         15.499us         16.384us

Shouldn’t CatBackward itself just be returning a few views of grad_out?
If so, I don’t think it is a good candidate for parallelization.
The processing further down the graph may be a different story, but that is a much more general problem.
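A quick way to check the first point: the gradient CatBackward hands to each input is exactly the matching slice of grad_out along the concatenation dimension. The sketch below (tensor sizes are arbitrary) only checks the values; that PyTorch produces them with cheap narrow/slice operations, as the narrow and slice rows in the profile suggest, is an implementation detail.

```python
import torch

a = torch.randn(2, 3, requires_grad=True)
b = torch.randn(4, 3, requires_grad=True)
out = torch.cat([a, b], dim=0)                # shape (6, 3)

grad_out = torch.arange(18.0).reshape(6, 3)   # arbitrary upstream gradient
grad_a, grad_b = torch.autograd.grad(out, (a, b), grad_outputs=grad_out)

# Each input's gradient is just the matching slice of grad_out along dim 0.
print(torch.equal(grad_a, grad_out.narrow(0, 0, 2)))  # True
print(torch.equal(grad_b, grad_out.narrow(0, 2, 4)))  # True
```

So the backward of the concatenation itself does almost no arithmetic; the real work happens in the branch-specific backward ops that follow it.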

Best regards


Thank you so much for your reply!

Is there any way to make these computations run in parallel?

My network is an ensemble model, and each branch network is quite small. I tested the model with ONE branch and with 32 branches, and the GPU utilization is the same in both cases and quite low (~10%). I thought that feeding the last layers of all branch networks to ONE optimizer would make the computation graph treat them as parallel processes, but that does not seem to be the case, as you said (CatBackward just returns a few views of grad_out).

def forward(self, x):
    # one output per branch; x[i] is the input of branch i
    return [self.model[i](x[i]) for i in range(number_branch_net)]

self.model is an nn.ModuleList, and all the model[i] are similar.
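A minimal, self-contained sketch of the setup described above, for anyone who wants to reproduce it. The class name EnsembleNet, the layer sizes and the Linear-ReLU-Linear branch architecture are assumptions: the AddmmBackward and ThresholdBackward0 rows in the profile suggest linear layers with ReLU activations, but the exact architecture was not posted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EnsembleNet(nn.Module):
    # Hypothetical ensemble of number_branch_net small, identical branches.
    def __init__(self, number_branch_net, in_features=16, hidden=32, out_features=4):
        super().__init__()
        self.number_branch_net = number_branch_net
        self.model = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.ReLU(),
                nn.Linear(hidden, out_features),
            )
            for _ in range(number_branch_net)
        ])

    def forward(self, x):
        # x[i] is the input of branch i; the branches are called one after
        # another, so on a GPU their small kernels are launched one after another.
        return [self.model[i](x[i]) for i in range(self.number_branch_net)]


net = EnsembleNet(number_branch_net=32)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # one optimizer for all branches

x = torch.randn(32, 8, 16)          # (branch, batch, features)
target = torch.zeros(8, 32 * 4)

out = torch.cat(net(x), dim=1)      # (8, 128); grad_fn is CatBackward
loss = F.mse_loss(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Sharing one optimizer only groups the parameters for the update step; it does not change how the forward or backward ops are scheduled, which is consistent with utilization staying the same for 1 and 32 branches.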


But have you solved the problem for the forward?
I have not looked into this in detail, but I think I saw discussions about running the backward on CUDA streams that match the ones used in the forward, which might help.
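A sketch of the CUDA-streams idea, assuming the branches and their inputs are already on the GPU. According to the PyTorch CUDA semantics notes, each backward CUDA op runs on the same stream that was used for its corresponding forward op, so running the branches' forward passes on separate streams lets the backward reuse those streams. The helper name forward_on_streams is hypothetical, and with very small branches the Python and kernel-launch overhead may still dominate, so this is not guaranteed to raise GPU utilization.

```python
import torch
import torch.nn as nn


def forward_on_streams(branches, xs):
    """Run branch i on its own CUDA stream; plain loop for CPU inputs.

    Hypothetical helper. For simplicity new streams are created on every
    call; a real implementation would probably reuse them.
    """
    if not xs[0].is_cuda:
        return [branch(x) for branch, x in zip(branches, xs)]

    current = torch.cuda.current_stream()
    streams = [torch.cuda.Stream() for _ in branches]
    outs = []
    # Launch every branch first, so no side stream waits on another.
    for branch, x, side in zip(branches, xs, streams):
        side.wait_stream(current)      # x was produced on the current stream
        with torch.cuda.stream(side):
            outs.append(branch(x))
        x.record_stream(side)          # tell the allocator x is also used on `side`
    # Then make the current stream wait for all side streams.
    for side, out in zip(streams, outs):
        current.wait_stream(side)      # the caller consumes out on the current stream
        out.record_stream(current)
    return outs
```

In the model above it would replace the list comprehension, e.g. return forward_on_streams(self.model, x), followed by torch.cat and loss.backward() as usual.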

If the models in your ensemble are structurally similar, it might also be worth trying to vectorize across models as an alternative. As a generic solution, you could also increase the batch size and run the ensemble members in sequence.
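A sketch of vectorizing across models, assuming every branch is a Linear-ReLU-Linear stack with identical sizes (an assumption; the class names BatchedLinear and VectorizedEnsemble are hypothetical). The weights of all branches are stacked along a leading model dimension, so each layer becomes one torch.baddbmm call instead of 32 small matmuls, in both the forward and the backward pass.

```python
import torch
import torch.nn as nn


class BatchedLinear(nn.Module):
    """num_models independent Linear layers evaluated with one batched matmul."""

    def __init__(self, num_models, in_features, out_features):
        super().__init__()
        # One (in_features, out_features) weight and one bias row per model.
        self.weight = nn.Parameter(
            torch.randn(num_models, in_features, out_features) / in_features ** 0.5
        )
        self.bias = nn.Parameter(torch.zeros(num_models, 1, out_features))

    def forward(self, x):
        # x: (num_models, batch, in_features) -> (num_models, batch, out_features)
        # baddbmm computes bias + x @ weight for every model in a single call.
        return torch.baddbmm(self.bias, x, self.weight)


class VectorizedEnsemble(nn.Module):
    def __init__(self, num_models, in_features=16, hidden=32, out_features=4):
        super().__init__()
        self.net = nn.Sequential(
            BatchedLinear(num_models, in_features, hidden),
            nn.ReLU(),
            BatchedLinear(num_models, hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)


ensemble = VectorizedEnsemble(num_models=32)
x = torch.randn(32, 8, 16)                    # (model, batch, features)
out = ensemble(x)                             # (32, 8, 4)
# Same layout as torch.cat over the per-model outputs along dim=1:
flat = out.permute(1, 0, 2).reshape(8, -1)    # (8, 128)
```

The weight initialization here is simplified and differs from nn.Linear's default. PyTorch 2.x also provides torch.func.stack_module_state together with torch.func.vmap and torch.func.functional_call, which vectorize an existing list of identical modules without rewriting the layers by hand.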

Best regards