I am implementing a multi-head network. (This is to implement multi-head DQN, a specific reinforcement learning method, but this doesn’t really matter here.)
My network has the following architecture:
input -> 128x (separate fully connected layers) -> output averaging
I am using a ModuleList to hold the list of fully connected layers. Here’s how it looks at this point:
import torch
import torch.nn as nn
import torch.optim as optim

class MultiHead(nn.Module):
    def __init__(self, dim_state, dim_action, hidden_size=32, nb_heads=1):
        super(MultiHead, self).__init__()

        # one independent two-layer MLP ("head") per ensemble member
        self.networks = nn.ModuleList()
        for _ in range(nb_heads):
            network = nn.Sequential(
                nn.Linear(dim_state, hidden_size),
                nn.Tanh(),
                nn.Linear(hidden_size, dim_action)
            )
            self.networks.append(network)

        self.cuda()
        self.optimizer = optim.Adam(self.parameters())
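For reference, I instantiate it with something like this (the dimensions and batch size below are just example values, not the ones from my actual experiments):

model = MultiHead(dim_state=8, dim_action=4, nb_heads=128)
observations = torch.rand(64, 8).cuda()  # dummy batch of 64 states, matching dim_state=8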
Then, when I need to calculate the output, I use a for ... in construct to perform the forward and backward pass through all the layers:
q_values = torch.cat([net(observations) for net in self.networks])
# skipped code which ultimately computes the loss I need
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
This works! But I am wondering whether I could do this more efficiently. With the for ... in construct, I actually go through each separate FC layer one by one, and as a result the training time grows with the number of FC layers. Can this operation be done in parallel?
This is similar to this (unanswered) question: Parallel execution of modules in nn.ModuleList
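For concreteness, this is roughly the kind of batched formulation I have in mind: stack every head's weights along a leading "head" dimension and evaluate all heads with two torch.bmm calls instead of a Python loop. This is only a sketch; the class name BatchedHeads and the weight initialization are placeholders I made up, and the output shape (nb_heads, batch, dim_action) differs from what my torch.cat line produces, so I am mainly asking whether something along these lines (or a built-in equivalent) is the idiomatic way to do it.

import torch
import torch.nn as nn

class BatchedHeads(nn.Module):
    # Sketch only: all heads evaluated together with batched matrix multiplies.
    def __init__(self, dim_state, dim_action, hidden_size=32, nb_heads=1):
        super().__init__()
        # one weight/bias tensor per layer, with a leading "head" dimension
        self.w1 = nn.Parameter(torch.randn(nb_heads, dim_state, hidden_size) * 0.1)
        self.b1 = nn.Parameter(torch.zeros(nb_heads, 1, hidden_size))
        self.w2 = nn.Parameter(torch.randn(nb_heads, hidden_size, dim_action) * 0.1)
        self.b2 = nn.Parameter(torch.zeros(nb_heads, 1, dim_action))

    def forward(self, observations):
        # observations: (batch, dim_state), broadcast to every head
        x = observations.unsqueeze(0).expand(self.w1.shape[0], -1, -1)  # (heads, batch, dim_state)
        x = torch.tanh(torch.bmm(x, self.w1) + self.b1)                 # (heads, batch, hidden)
        return torch.bmm(x, self.w2) + self.b2                          # (heads, batch, dim_action)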