Moving a small tensor to the GPU adds a huge bottleneck

I was profiling my code and noticed a list comprehension taking a long time. The list comprehension was in the following method:

    def get_targets(self, next_states):
        # (batch, ensemble, obs) -> (ensemble, ensemble, batch, obs)
        next_states = torch.stack([next_states[:, i, :].unsqueeze(dim=0).repeat(self.ensemble_size, 1, 1) for i in range(self.ensemble_size)])
        with torch.no_grad():
            targets_ = self.critic_target.batched_forward_(next_states).max(dim=-1)[0].mean(dim=-1)
            # draw two random ensemble indices per member, then take the min over them
            idx = [torch.randperm(self.ensemble_size)[:2] for _ in range(self.ensemble_size)]
            targets = torch.cat([targets_[i][:, idx[i]].min(dim=-1, keepdim=True)[0] for i in range(self.ensemble_size)], dim=1)
        return targets

The list comprehension in question is the line that assigns the targets variable. So, I thought I would try to make this quicker by using the following:

    def get_targets(self, next_states):
        # (batch, ensemble, obs) -> (ensemble, ensemble, batch, obs)
        next_states = torch.stack([next_states[:, i, :].unsqueeze(dim=0).repeat(self.ensemble_size, 1, 1) for i in range(self.ensemble_size)])
        with torch.no_grad():
            targets_ = self.critic_target.batched_forward_(next_states).max(dim=-1)[0].mean(dim=-1)
            # build all the index pairs at once on the CPU, then move them to the GPU
            idx = torch.stack([torch.randperm(self.ensemble_size)[:2] for _ in range(self.ensemble_size)]).to(self.device)
            # gather the two selected targets per member and take the elementwise min
            targets = targets_.gather(-1, idx.unsqueeze(dim=1).repeat(1, self.batch_size, 1)).min(dim=-1)[0].transpose(0, 1)
        return targets

Now, after profiling, the overall time is slightly faster, but there is still a huge bottleneck in moving the idx tensor onto the GPU; the .to() call runs exactly once per get_targets call. What is most confusing is that I move much larger tensors onto the GPU elsewhere in the file and they are much quicker. Does anybody know what could be causing this?
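To make the comparison concrete, here is a rough standalone timing sketch (an illustration rather than my exact profiling setup: it synchronizes around the copy loop so that previously queued GPU kernels are not billed to the first .to() call, and it falls back to the CPU when no GPU is available):

```python
import time
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def time_transfer(t, n=1000):
    """Average seconds per t.to(device) over n copies."""
    if device.type == "cuda":
        torch.cuda.synchronize()  # drain any queued kernels first
    start = time.perf_counter()
    for _ in range(n):
        t.to(device)
    if device.type == "cuda":
        torch.cuda.synchronize()  # wait for the copies to finish
    return (time.perf_counter() - start) / n

small = torch.randint(0, 25, (25, 2))  # shaped like idx
large = torch.randn(512, 25, 9)        # shaped like next_states
print(f"small: {time_transfer(small) * 1e6:.1f} us/copy")
print(f"large: {time_transfer(large) * 1e6:.1f} us/copy")
```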

For reproducibility: next_states is a tensor of size (512, 25, 9); replace the line targets_ = self.critic_target.batched_forward_(next_states).max(dim=-1)[0].mean(dim=-1) with targets_ = torch.randn((25, 512, 25)); and self.ensemble_size and self.batch_size are 25 and 512, respectively.
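Putting those pieces together, the bottlenecked path can be reproduced with a standalone snippet along these lines (shapes as stated above; it runs on the CPU if no GPU is present):

```python
import torch

ensemble_size, batch_size = 25, 512
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# stand-in for self.critic_target.batched_forward_(...).max(dim=-1)[0].mean(dim=-1)
targets_ = torch.randn((ensemble_size, batch_size, ensemble_size), device=device)

# the line whose .to() call dominates the profile
idx = torch.stack([
    torch.randperm(ensemble_size)[:2]
    for _ in range(ensemble_size)
]).to(device)

targets = (targets_
           .gather(-1, idx.unsqueeze(dim=1).repeat(1, batch_size, 1))
           .min(dim=-1)[0]
           .transpose(0, 1))
print(targets.shape)  # torch.Size([512, 25])
```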

Below is a snapshot of the profiling results, demonstrating how the .to() method comes to be dominated by this single line. The LHS is the code after making the change. We can see that the list comprehension that previously dominated the time is gone, but with only 10k more calls to the .to() method (which checks out with how many times the method above was called during profiling), we are now almost 10x slower in .to(). As I mentioned previously, I don't know why this would be the case, since I put much larger tensors on the GPU using this code:

    def get_batch(self):
        batch = self.memory.sample()
        states, actions, rewards, next_states, dones = batch
        states = states.view(self.batch_size, self.ensemble_size, -1).to(self.device)
        actions = actions.view(self.batch_size, self.ensemble_size, -1).to(self.device)
        rewards = rewards.view(self.batch_size, self.ensemble_size).to(self.device)
        next_states = next_states.view(self.batch_size, self.ensemble_size, -1).to(self.device)
        dones = dones.view(self.batch_size, self.ensemble_size).to(self.device)

        return states, actions, rewards, next_states, dones

where the dimensions are much larger than those of the small idx tensor. In case it is relevant, states, next_states, and rewards are all float tensors, whilst dones and actions are long tensors.
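For what it's worth, the transfer itself can be sidestepped by generating the permutations directly on the GPU, since torch.randperm accepts a device argument (sketch below), but I'd still like to understand why the tiny copy is so expensive:

```python
import torch

ensemble_size = 25
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# build idx on the target device from the start -- no host-to-device copy
idx = torch.stack([
    torch.randperm(ensemble_size, device=device)[:2]
    for _ in range(ensemble_size)
])
print(idx.shape)  # torch.Size([25, 2])
```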