Running multiple Modules in parallel

I have tried using CUDA streams, but I still see a big slowdown when scaling up the number of heads:

    # One squared-TD-error update over all heads in `self.networks`, with each
    # head's forward passes issued on its own CUDA stream.
    streams = [torch.cuda.Stream() for _ in range(nb_heads)]
    # Block until all previously queued GPU work has finished, so the side
    # streams do not start reading inputs that are still being produced.
    torch.cuda.synchronize()

    losses = []
    for net_idx, net in enumerate(self.networks):
        with torch.cuda.stream(streams[net_idx]):
            # Q-value of the action that was actually taken.
            # presumably `actions` is a 1-D tensor of action indices -- TODO confirm
            q_values = net(observations)
            q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

            # The bootstrap target is used as a constant (the original code
            # detached it), so skip building an autograd graph for this
            # second forward pass instead of building one and discarding it.
            with torch.no_grad():
                next_q_values = net(next_observations).max(1)[0]
                expected_q_values = rewards + gamma * next_q_values * (1 - terminals)

            # Per-sample TD error for this head.
            losses.append(q_values - expected_q_values)

    # Wait for every stream to finish before the results are combined on the
    # current stream.
    torch.cuda.synchronize()
    loss = torch.cat(losses).pow(2).mean()

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()