Low performance of backward in ensemble network

Hi,
I’m training a ensemble model which has N-sub similar models as branch networks and I use torch.cat to concatenate the last layer of N sub-models. (code bellow:)

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.clip_grad import clip_grad_norm

from common import USE_CUDA, FloatTensor, NetType
from debug_utils import opts

# F/R Network definition
class Net(nn.Module):

    @staticmethod
    def _make_net(net: NetType, activation_func) -> nn.Module:
        if isinstance(net, nn.Module):
            # net is already defined, nothing else to do
            return net
        elif isinstance(net, list):
            layers = []
            for j in range(1, len(net)):
                dim_in = net[j - 1]
                dim_out = net[j]
                layers.append(nn.Linear(dim_in, dim_out))
                if j < len(net) - 1:
                    layers.append(activation_func())
            return nn.Sequential(*layers)
        else:
            raise NotImplementedError()

    def __init__(self, net, lr: float):
        super().__init__()
        self.model = Net._make_net(net, opts.hidden_activation)
        self._add_auxiliary_params()

        if USE_CUDA:
            self.cuda()
            
        self._optimizer = opts.optimizer(self.parameters(), lr=lr)

    def _add_auxiliary_params(self):
        pass

    def forward(self, x: FloatTensor) -> FloatTensor:
        return self.model(x)

    def _loss(self, outputs: FloatTensor, targets: FloatTensor):
        return F.mse_loss(outputs, targets)

    def update_parameters(self, inputs: FloatTensor, targets: FloatTensor):
        outputs = self(inputs)
        loss = self._loss(outputs, targets)
        self._optimizer.zero_grad()
        with torch.autograd.profiler.profile() as prof:
            loss.backward()
        print(prof)
        self._optimizer.step()
        
    def save(self, filename):
        torch.save(dict(
            model=self.model.state_dict(),
            optimizer=self._optimizer.state_dict()
        ), filename)

    def load(self, filename):
        checkpoint = torch.load(filename)
        self.model.load_state_dict(checkpoint['model'])
        self._optimizer.load_state_dict(checkpoint['optimizer'])
        if USE_CUDA:
            self.cuda()

class EnsembleNet(Net):
    def __init__(self, net, lr: float):
        ensemble_net = nn.ModuleList([Net._make_net(net, opts.hidden_activation)
                                      for _ in range(opts.bootstrap_size)])
        super().__init__(ensemble_net, lr)

    def forward(self, x):
        return torch.cat([self.model[i](x[i]) for i in range(opts.bootstrap_size)])

My problem is that the backward seems to be sequentially backward each sub-model (as torch.autograd.profiler.profile() in attached file below I test with N=32, after CatBackWard, it narrow and slice 32 times … )

Name                                        CPU time        CUDA time            Calls        CPU total       CUDA total
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------
torch::autograd::GraphRoot                   5.200us          3.072us                1          5.200us          3.072us
MseLossBackward                             62.497us         59.360us                1         62.497us         59.360us
mse_loss_backward                           47.498us         48.096us                1         47.498us         48.096us
CatBackward                                381.981us        404.480us                1        381.981us        404.480us
narrow                                      12.000us         12.256us                1         12.000us         12.256us
slice                                        5.500us          5.088us                1          5.500us          5.088us
.....
32 times: narrow_slice
....
narrow                                       6.999us          8.192us                1          6.999us          8.192us
slice                                        2.699us          2.080us                1          2.699us          2.080us
AddmmBackward                               88.696us         84.992us                1         88.696us         84.992us
unsigned short                               6.100us          4.096us                1          6.100us          4.096us
mm                                          24.798us         27.648us                1         24.798us         27.648us
unsigned short                               5.300us          5.152us                1          5.300us          5.152us
mm                                          25.498us         29.728us                1         25.498us         29.728us
unsigned short                               5.000us          3.072us                1          5.000us          3.072us
sum                                         21.799us         22.528us                1         21.799us         22.528us
view                                         8.299us          8.160us                1          8.299us          8.160us
torch::autograd::AccumulateGrad             18.099us         18.432us                1         18.099us         18.432us
TBackward                                   10.999us         11.264us                1         10.999us         11.264us
unsigned short                               5.700us          5.120us                1          5.700us          5.120us
torch::autograd::AccumulateGrad             13.699us         14.304us                1         13.699us         14.304us
ThresholdBackward0                          25.999us         26.624us                1         25.999us         26.624us
threshold_backward                          18.599us         19.456us                1         18.599us         19.456us
AddmmBackward                               83.496us         84.960us                1         83.496us         84.960us
unsigned short                               4.700us          4.096us                1          4.700us          4.096us
mm                                          26.399us         31.744us                1         26.399us         31.744us
unsigned short                               5.100us          3.072us                1          5.100us          3.072us
mm                                          21.899us         25.568us                1         21.899us         25.568us
unsigned short                               4.300us          2.048us                1          4.300us          2.048us
sum                                         16.800us         18.432us                1         16.800us         18.432us
view                                         6.900us          5.152us                1          6.900us          5.152us
torch::autograd::AccumulateGrad             14.999us         15.360us                1         14.999us         15.360us
TBackward                                    9.899us         10.272us                1          9.899us         10.272us
unsigned short                               4.799us          4.128us                1          4.799us          4.128us
torch::autograd::AccumulateGrad             13.199us         14.336us                1         13.199us         14.336us
ThresholdBackward0                          22.299us         22.560us                1         22.299us         22.560us
threshold_backward                          15.499us         16.384us                1         15.499us         16.384us
........

The gpu utilization is very low (<10%) and I test with N=1 and N=32 the gpu utilization is the same.


Is it possible to make sub-models can be update parallel. Thanks!