I have tried using CUDA streams, but I still see a big slowdown when scaling up the number of heads:
# One optimization step over every head in `self.networks`: each head's TD
# error is enqueued on its own CUDA stream, then all errors are concatenated
# into a single mean-squared loss that drives one backward pass and one
# optimizer step.
# NOTE(review): if this runs once per training step, consider creating the
# streams once and reusing them instead of rebuilding the list every call.
streams = [torch.cuda.Stream() for _ in range(nb_heads)]
# Device-wide barrier before the side streams start reading the input tensors.
torch.cuda.synchronize()
td_errors = []
# enumerate replaces the hand-maintained counter; indexing `streams` still
# raises IndexError if there are more networks than `nb_heads`.
for head_idx, net in enumerate(self.networks):
    with torch.cuda.stream(streams[head_idx]):
        # Select the Q-value of the action that was taken.
        # assumes q-values are (batch, n_actions) and `actions` is an integer
        # (batch,) tensor — TODO confirm against the caller
        q_values = net(observations)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
        # The bootstrap target is treated as a constant, so build it without
        # autograd tracking: same gradients as detaching afterwards, but no
        # graph is recorded for the second forward pass.
        with torch.no_grad():
            next_q_values = net(next_observations).max(1)[0]
            expected_q_values = rewards + gamma * next_q_values * (1 - terminals)
        td_errors.append(q_values - expected_q_values)
# Wait for every stream to finish before combining the per-head errors.
torch.cuda.synchronize()
loss = torch.cat(td_errors).pow(2).mean()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()