I am trying to implement an ensemble, and for my use case I only need uncertainty estimates over the final layer's outputs. To that end, I first implemented a network consisting of a few Linear layers (with ReLU non-linearities) followed by N output layers, so that for a single input the model outputs N predictions, as in a normal ensemble.
I then implemented an ensemble similar to this, except that it is essentially just N copies of the network (but with a single set of output heads in each copy). When I train both on exactly the same data, backprop takes almost twice as long in the first scenario, despite it having an order of magnitude fewer parameters. Does anyone have an idea why this would be? It seems counterintuitive to me that, with fewer parameters and the same amount of data, backprop would take longer.
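For reference, the comparison I'm timing looks roughly like this (a simplified sketch, not my actual training loop; model, inputs, targets and loss_fn are placeholders, and batched_forward is the ensemble forward pass from the code I post further down):

import time
import torch

def time_backward(model, inputs, targets, loss_fn, n_iters=100):
    # Rough timing of forward + backward, averaged over n_iters iterations.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        preds = model.batched_forward(inputs)
        loss = loss_fn(preds, targets)
        model.zero_grad()
        loss.backward()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters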
Thanks!
Could you post the model implementations you are using (a proxy model which shows the general idea might be sufficient)?
Sure!
class Ensemble1(nn.Module):
    def __init__(self, state_dim, hidden_dim, num_actions, num_heads, ensemble_size):
        super(Ensemble1, self).__init__()
        # Shared trunk
        self.input_layer = nn.Linear(state_dim, hidden_dim)
        self.h1 = nn.Linear(hidden_dim, hidden_dim)
        self.h2 = nn.Linear(hidden_dim, hidden_dim)
        # ensemble_size groups of num_heads output heads
        self.ensemble = nn.ModuleList()
        for j in range(ensemble_size):
            self.ensemble.append(nn.ModuleList())
            for i in range(num_heads):
                self.ensemble[j].append(nn.Linear(hidden_dim, num_actions))
        self.num_heads = num_heads

    def batched_forward(self, x):
        # x has shape (batch_size, ensemble_size, state_dim); nn.Linear acts on the last dim
        x = torch.relu(self.input_layer(x))
        x = torch.relu(self.h1(x))
        x = torch.relu(self.h2(x))
        ensembles = []
        for idx, heads in enumerate(self.ensemble):
            vals = []
            for head in heads:
                vals.append(head(x[:, idx, :]))
            ensembles.append(torch.stack(vals, dim=1))
        return torch.stack(ensembles, dim=1)
class Ensemble2(nn.Module):
    def __init__(self, state_dim, hidden_dim, num_actions, num_heads, ensemble_size):
        super(Ensemble2, self).__init__()
        # ensemble_size independent networks, each with its own trunk and heads
        self.ensemble = nn.ModuleList()
        for e in range(ensemble_size):
            self.ensemble.append(MLP(state_dim, hidden_dim, num_actions, num_heads))

    def batched_forward(self, x):
        vals = []
        for idx, net in enumerate(self.ensemble):
            vals.append(net(x[:, idx, :]).transpose(0, 1))
        return torch.stack(vals, dim=1)
Ensemble1 is the method with the slower backprop, even though the model is more lightweight than Ensemble2. In Ensemble2, the MLP model is essentially the same architecture as Ensemble1, but without the extra ensemble_size - 1 sets of heads. batched_forward is how I run a forward pass for the ensemble. My inputs are all vectors of size state_dim, so when I draw a batch the tensor has shape (batch_size * ensemble_size, state_dim), which (after shuffling the rows) I reshape to (batch_size, ensemble_size, state_dim).
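Concretely, the batch preparation looks roughly like this (a sketch with made-up sizes, using the Ensemble1 class above and assuming import torch):

state_dim, hidden_dim, num_actions, num_heads, ensemble_size = 8, 64, 4, 5, 10
batch_size = 32
model = Ensemble1(state_dim, hidden_dim, num_actions, num_heads, ensemble_size)

batch = torch.randn(batch_size * ensemble_size, state_dim)   # drawn batch of state vectors
batch = batch[torch.randperm(batch.size(0))]                 # shuffle the rows
batch = batch.reshape(batch_size, ensemble_size, state_dim)  # reshape before the forward pass
preds = model.batched_forward(batch)                         # (batch_size, ensemble_size, num_heads, num_actions)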
The code is not executable, as MLP is undefined.
class MLP(nn.Module):
    def __init__(self, state_dim, hidden_dim, num_actions, num_heads):
        super(MLP, self).__init__()
        self.input_layer = nn.Linear(state_dim, hidden_dim)
        self.h1 = nn.Linear(hidden_dim, hidden_dim)
        self.h2 = nn.Linear(hidden_dim, hidden_dim)
        self.output_heads = nn.ModuleList()
        for i in range(num_heads):
            self.output_heads.append(nn.Linear(hidden_dim, num_actions))
        self.num_heads = num_heads

    def forward(self, x):
        x = torch.relu(self.input_layer(x))
        x = torch.relu(self.h1(x))
        x = torch.relu(self.h2(x))
        vals = []
        for i in range(self.num_heads):
            vals.append(self.output_heads[i](x))
        return torch.stack(vals)  # shape: (num_heads, batch_size, num_actions)
Here you go.
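With MLP defined, both models run end to end for me; a quick shape sanity check (with made-up dimensions) looks like this:

import torch

state_dim, hidden_dim, num_actions, num_heads, ensemble_size = 8, 64, 4, 5, 10
batch_size = 32

m1 = Ensemble1(state_dim, hidden_dim, num_actions, num_heads, ensemble_size)
m2 = Ensemble2(state_dim, hidden_dim, num_actions, num_heads, ensemble_size)

x = torch.randn(batch_size, ensemble_size, state_dim)
out1 = m1.batched_forward(x)
out2 = m2.batched_forward(x)
# both come out as (batch_size, ensemble_size, num_heads, num_actions)
print(out1.shape, out2.shape)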
@ptrblck were you able to look into this?