Randomly select items from two equally sized tensors

Assume that we have two equally sized tensors of size batch_size x 1. For each index in the batch dimension, we want to choose randomly between the two tensors. My solution was to create an indices tensor of size batch_size containing random 0 or 1 values, and use it to index_select from the concatenation of the two tensors. However, to do so I had to “view” the concatenated tensor, and the solution ended up being quite “ugly”:

import torch

bs = 8
a = torch.zeros(bs, 1)
print("a size", a.size())
b = torch.ones(bs, 1)

c = torch.cat([a, b], dim=-1)
print(c)
print("c size", c.size())

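# create bs random 0s and 1s (0 -> take from a, 1 -> take from b)
indices = torch.randint(0, 2, [bs])
print("idxs size", indices.size())
print("idxs", indices)

# use `indices` to pick column 0 (from a) or column 1 (from b) in each row;
# in the flattened [bs, 2] tensor, row i occupies positions 2*i and 2*i + 1
flat_idx = torch.arange(bs) * 2 + indices
d = c.view(-1).index_select(0, flat_idx).view(-1, 1)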
print("d size", d.size())
print(d)

I am wondering whether there is a prettier and, more importantly, more efficient solution.
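One possibly tidier way to express the per-row choice described above is `torch.gather`, which picks `c[i, indices[i]]` for every row `i`, so no flattening or `view` juggling is needed. This is a minimal sketch reusing the names from the question's code; it has not been benchmarked against the `index_select` version.

```python
import torch

bs = 8
a = torch.zeros(bs, 1)
b = torch.ones(bs, 1)

# [bs, 2]: column 0 holds a[i], column 1 holds b[i]
c = torch.cat([a, b], dim=-1)

# one random column index per row (0 -> a, 1 -> b)
indices = torch.randint(0, 2, (bs,))

# gather selects c[i, indices[i]] for each row i; the index tensor must have
# the same number of dimensions as c, hence unsqueeze(1) -> [bs, 1]
d = c.gather(1, indices.unsqueeze(1))  # [bs, 1]
print(d.size())
print(d)
```

Because `a` is all zeros and `b` is all ones, each value in `d` simply equals the chosen index.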

You can do it this way:

import torch

bs = 8
a = torch.zeros(bs, 1)
print("a size", a.size())
b = torch.ones(bs, 1)

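# draw bs random row indices out of the 2 * bs stacked rows (with replacement)
idx = torch.randint(2 * bs, (bs,))

d = torch.cat([a, b])[idx]  # cat along dim 0 gives [2 * bs, 1]; indexing gives [bs, 1]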

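d is a [bs, 1] tensor containing a random sample of rows from both batches, possibly with duplicates. This is OK for most statistical applications; it is essentially a form of bootstrapping. Note that it does not keep the per-index pairing: row i of d is not necessarily a[i] or b[i].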
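If each row of the result must come from the same position in `a` or `b` (which is what the question asks for), a minimal sketch is one coin flip per row combined with `torch.where`, which avoids building the concatenated tensor at all. The distinct values in `a` and `b` are only there to make the origin of each row visible, and the 0.5 probability is an assumption; no benchmark is implied.

```python
import torch

bs = 8
# distinct values so each row's origin is visible: a -> 0..7, b -> 100..107
a = torch.arange(bs, dtype=torch.float32).unsqueeze(1)        # [bs, 1]
b = torch.arange(bs, dtype=torch.float32).unsqueeze(1) + 100  # [bs, 1]

# one fair coin flip per row: True -> take b[i], False -> take a[i]
mask = torch.rand(bs, 1) < 0.5

# elementwise selection; the row correspondence with a and b is preserved
d = torch.where(mask, b, a)  # [bs, 1]
print(d)
```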

According to this link, this may be the fastest way.

import torch

bs = 8
a = torch.zeros(bs, 1)
b = torch.ones(bs, 1)
c = torch.cat([a, b], dim=-1)
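choices_flat = c.view(-1)  # all 2 * bs candidate values, shape [2 * bs]
# sampling with replacement:
# index = torch.randint(choices_flat.numel(), (bs,))
# sampling without replacement (replace=False):
index = torch.randperm(choices_flat.numel())[:bs]
select = choices_flat[index]  # shape [bs]; use .view(-1, 1) for [bs, 1]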

print(select)
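Both variants in the code above can also be written with `torch.multinomial`, whose `replacement` argument switches between sampling with and without replacement. Uniform weights are assumed here (non-uniform weights would bias the draw toward some values). This is a sketch of an alternative, not a claim that it is faster than `randperm`.

```python
import torch

bs = 8
a = torch.zeros(bs, 1)
b = torch.ones(bs, 1)

# same flattened pool of 2 * bs candidate values as in the answer above
choices_flat = torch.cat([a, b], dim=-1).view(-1)  # [2 * bs]

# uniform (unnormalized) weights: every position is equally likely
weights = torch.ones(choices_flat.numel())

# replacement=False: each position is used at most once, like randperm(...)[:bs]
idx_no_repl = torch.multinomial(weights, bs, replacement=False)
# replacement=True: positions may repeat, like randint(numel, (bs,))
idx_repl = torch.multinomial(weights, bs, replacement=True)

print(choices_flat[idx_no_repl])
print(choices_flat[idx_repl])
```

One advantage of this form is that the same call handles weighted sampling by changing `weights`.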