I would like to train an ensemble of networks that all share the same architecture. My question is how to define the ensemble so that PyTorch parallelizes the backprop automatically.
Currently, I'm defining it like this:
```python
import torch
import torch.nn as nn

def make_net(network_specs):
    ...  # build and return an nn.Module from network_specs

class Ensemble(nn.Module):
    def __init__(self, network_specs, ensemble_size):
        super().__init__()
        self.ensemble_size = ensemble_size
        self.model = nn.ModuleList([make_net(network_specs) for _ in range(ensemble_size)])

    def forward(self, x):
        # x: (ensemble_size, batch, ...); each member sees its own slice
        return torch.cat([self.model[i](x[i]) for i in range(self.ensemble_size)])
```
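For context, this is roughly how I use it; `network_specs`, `batch_size`, and `input_dim` are placeholders for whatever `make_net` actually expects:

```python
# hypothetical usage: one sub-batch per ensemble member
ens = Ensemble(network_specs, ensemble_size=4)
x = torch.randn(4, batch_size, input_dim)
loss = ens(x).sum()  # toy loss, just to exercise backprop
loss.backward()      # populates grads for every member's parameters
```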
However, the backward pass through `cat` doesn't seem to be parallelized: the members still run one after another in the Python loop, both forward and backward.
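One direction I've been looking at is the `torch.func` ensembling pattern, which stacks the per-member parameters and `vmap`s a single functional forward over them. A minimal sketch, assuming every member has exactly the same architecture (a meta-device copy serves as the stateless template):

```python
import copy
import torch
from torch.func import stack_module_state, functional_call, vmap

models = [make_net(network_specs) for _ in range(ensemble_size)]
params, buffers = stack_module_state(models)  # each leaf gains a leading ensemble dim

# stateless "template" module; its own weights are never used
base = copy.deepcopy(models[0]).to('meta')

def fmodel(params, buffers, x):
    return functional_call(base, (params, buffers), (x,))

# x: (ensemble_size, batch, ...) -> all members run as one batched op
out = vmap(fmodel)(params, buffers, x)
```

My understanding is that this turns the per-member ops into batched ones, so both forward and backward are vectorized rather than looped, but I'm not sure it's the intended way.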
What would be a better or cleaner way to implement an ensemble?
Thanks.