I am implementing a multi-head network. (The goal is multi-head DQN, a specific reinforcement learning method, but the details don't really matter here.)
My network has the following architecture:
input -> 128x (separate fully connected layers) -> output averaging
I am using a ModuleList to hold the list of fully connected layers. Here’s how it looks at this point:
```python
import torch
import torch.nn as nn
import torch.optim as optim

class MultiHead(nn.Module):
    def __init__(self, dim_state, dim_action, hidden_size=32, nb_heads=1):
        super(MultiHead, self).__init__()

        # One small fully connected network per head
        self.networks = nn.ModuleList()
        for _ in range(nb_heads):
            network = nn.Sequential(
                nn.Linear(dim_state, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, dim_action)
            )
            self.networks.append(network)

        self.cuda()
        self.optimizer = optim.Adam(self.parameters())
```
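For illustration, it can be instantiated like this (the dimensions below are placeholder values, not my real ones):

```python
# Illustrative instantiation; dim_state/dim_action values are placeholders.
model = MultiHead(dim_state=8, dim_action=4, nb_heads=128)
observations = torch.randn(256, 8).cuda()  # a batch of 256 states
```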
Then, when I need to calculate the output, I use a
for ... in construct to perform the forward and backward pass through all the layers:
```python
q_values = torch.cat([net(observations) for net in self.networks])

# skipped code which ultimately computes the loss I need

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
```
This works! But I am wondering if I could do this more efficiently. Because of the for...in, each separate FC network is evaluated sequentially, one by one, and as a result the training time grows with the number of heads.
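To make this concrete, here is a rough micro-benchmark sketch of a single forward pass (all sizes are made up, and it assumes a CUDA device since the constructor calls self.cuda()):

```python
import time
import torch

# Rough illustration of how forward time scales with the number of heads.
# All sizes are made-up placeholders.
dim_state, dim_action, batch_size = 8, 4, 256
observations = torch.randn(batch_size, dim_state).cuda()

for nb_heads in (1, 32, 128):
    model = MultiHead(dim_state, dim_action, nb_heads=nb_heads)
    torch.cuda.synchronize()  # finish pending kernels before timing
    start = time.perf_counter()
    q_values = torch.cat([net(observations) for net in model.networks])
    torch.cuda.synchronize()  # wait for the forward pass to complete
    print(f"{nb_heads:4d} heads: {time.perf_counter() - start:.4f}s")
```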
Can this operation be done in parallel?
This is similar to this (unanswered) question: Parallel execution of modules in nn.ModuleList
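For what it's worth, one direction I have considered (an untested sketch of my own, not something from the linked question): stack the per-head weights into 3-D tensors so that each linear layer, across all heads at once, becomes a single batched matrix multiply with torch.baddbmm. The class name and the crude initialization below are my own placeholders:

```python
import torch
import torch.nn as nn

class MultiHeadBatched(nn.Module):
    """Untested sketch: all heads evaluated with two batched matmuls."""
    def __init__(self, dim_state, dim_action, hidden_size=32, nb_heads=1):
        super().__init__()
        # Per-head weights stacked along a leading head dimension.
        # The initialization here is a crude placeholder, not tuned.
        self.w1 = nn.Parameter(torch.randn(nb_heads, dim_state, hidden_size) * 0.1)
        self.b1 = nn.Parameter(torch.zeros(nb_heads, 1, hidden_size))
        self.w2 = nn.Parameter(torch.randn(nb_heads, hidden_size, dim_action) * 0.1)
        self.b2 = nn.Parameter(torch.zeros(nb_heads, 1, dim_action))

    def forward(self, x):
        # x: (batch, dim_state) -> (nb_heads, batch, dim_state)
        x = x.unsqueeze(0).expand(self.w1.shape[0], -1, -1)
        h = torch.tanh(torch.baddbmm(self.b1, x, self.w1))
        # Result: (nb_heads, batch, dim_action); .reshape(-1, dim_action)
        # recovers the same layout as torch.cat over the per-head outputs.
        return torch.baddbmm(self.b2, h, self.w2)
```

If this works, the forward pass becomes two batched matmul kernels regardless of the number of heads, instead of two kernels per head, which is where I would expect the speedup to come from. Newer PyTorch versions also offer torch.vmap, which might achieve something similar, but I haven't tried it.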