How to implement an ensemble efficiently?

I would like to train an ensemble of networks that all share the same architecture. My question is how to define the ensemble so that PyTorch parallelizes the backward pass across the members automatically.

Currently, I’m defining it this way.

import torch
import torch.nn as nn

def make_net(network_specs):
    # build and return a single nn.Module from the specs
    ...

class Ensemble(nn.Module):
    def __init__(self, network_specs, ensemble_size):
        super().__init__()
        self.model = nn.ModuleList([make_net(network_specs) for _ in range(ensemble_size)])

    def forward(self, x):
        # x[i] is the input batch for the i-th ensemble member
        return torch.cat([net(xi) for net, xi in zip(self.model, x)])
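For reference, this is roughly how I call it; the shapes below are just an illustration of my setup, not fixed requirements:

# hypothetical shapes for illustration:
# 4 ensemble members, batch of 32, 10 input features
ensemble = Ensemble(network_specs, ensemble_size=4)
x = torch.randn(4, 32, 10)   # one input batch per member
out = ensemble(x)            # cat over members: shape (4 * 32, out_dim)
loss = out.sum()
loss.backward()              # replays each member's graph one after another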

However, the backward pass through torch.cat doesn't seem to be parallelized: since the members are invoked in a Python loop, autograd replays each member's forward/backward sequentially.

What would be a better or cleaner way to implement an ensemble?
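For instance, would something along the lines of torch.func's stacked-state approach be the right direction? A minimal sketch, assuming all members share the same architecture and PyTorch >= 2.0 (batch_size and in_features are placeholder names, untested on my end):

import copy
import torch
from torch.func import stack_module_state, functional_call

models = [make_net(network_specs) for _ in range(ensemble_size)]

# stack the parameters/buffers of all members into single batched tensors
params, buffers = stack_module_state(models)

# a stateless "template" module on the meta device, used only for its structure
base_model = copy.deepcopy(models[0]).to('meta')

def call_single_model(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

# vmap over the stacked parameter dimension and the per-member inputs,
# so one batched forward (and backward) covers the whole ensemble
x = torch.randn(ensemble_size, batch_size, in_features)  # hypothetical shapes
out = torch.vmap(call_single_model)(params, buffers, x)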

Thanks.