I would like to train an ensemble of networks that all share the same architecture. My question is how to define the ensemble so that PyTorch parallelizes the backward pass automatically.
Currently, I’m defining it this way.
```python
import torch
import torch.nn as nn

def make_net(network_specs):
    # builds and returns an nn.Module from the given specs
    ...

class Ensemble(nn.Module):
    def __init__(self, network_specs, ensemble_size):
        super().__init__()
        self.ensemble_size = ensemble_size
        self.model = nn.ModuleList(
            [make_net(network_specs) for _ in range(ensemble_size)]
        )

    def forward(self, x):
        # x has a leading ensemble dimension: x[i] goes to member i
        return torch.cat(
            [self.model[i](x[i]) for i in range(self.ensemble_size)]
        )
```
However, the backward pass through `torch.cat` doesn't seem to be parallelized. As far as I can tell, the Python list comprehension dispatches the members one after another, so both the forward and backward passes run sequentially rather than as a single batched computation.
What would be a better or cleaner way to implement an ensemble?
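For reference, here is a minimal sketch of the direction I've been looking at, using `torch.func` (available in PyTorch 2.0+): stack the per-member parameters into single tensors with a leading ensemble dimension, then `vmap` one functional forward over them, so all members execute as a vectorized computation and autograd differentiates through the batched call. `make_net`, `network_specs`, and `ensemble_size` are from my snippet above; the input shape at the end is hypothetical.

```python
import copy
import torch
from torch.func import stack_module_state, functional_call

# Build the members as before, then stack their parameters and
# buffers into tensors with a leading ensemble dimension.
models = [make_net(network_specs) for _ in range(ensemble_size)]
params, buffers = stack_module_state(models)

# A stateless "template" module; its own weights are never used.
base_model = copy.deepcopy(models[0]).to("meta")

def call_single(params, buffers, x):
    # run one ensemble member functionally, with the given weights
    return functional_call(base_model, (params, buffers), (x,))

# vmap over dim 0 of params, buffers, and the input, so every
# member runs inside one batched kernel launch per op.
ensemble_forward = torch.vmap(call_single)

x = torch.randn(ensemble_size, 32, 10)  # hypothetical input shape
out = ensemble_forward(params, buffers, x)
```

If I understand `stack_module_state` correctly, the stacked `params` are new leaf tensors, so gradients accumulate there rather than on the original modules, and the optimizer would be constructed over `params.values()`. Is this the idiomatic way, or is there a cleaner pattern?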