Torch equivalent of numpy.random.choice?

Thanks!
Actually I wanted to draw k samples without replacement, so I ended up doing:
perm = torch.randperm(tensor.size(0))
idx = perm[:k]
samples = tensor[idx]

(but maybe that’s not computationally efficient)

25 Likes

Well, the main advantage of numpy.random.choice is the possibility of passing an array of probabilities corresponding to each element, which this solution does not cover. Even Python's built-in random library lets you pass a weight list to its choices() function.
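For instance, with the standard library (a minimal illustration; the values and weights here are made up):

import random

population = [10, 20, 30]
weights = [0.1, 0.2, 0.7]
# random.choices draws WITH replacement, weighting each element accordingly
samples = random.choices(population, weights=weights, k=5)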

2 Likes

Oh, are you looking for torch.multinomial? http://pytorch.org/docs/master/torch.html?highlight=multinomial#torch.multinomial

18 Likes

There is an issue currently opened in PyTorch’s github repo about that subject: https://github.com/pytorch/pytorch/issues/16897

1 Like

A first version of a full-featured numpy.random.choice equivalent for PyTorch is now available here (working with PyTorch 1.0.0). It includes CPU and CUDA implementations of:

  • Uniform Random Sampling WITH Replacement (via torch::randint)
  • Uniform Random Sampling WITHOUT Replacement (via reservoir sampling)
  • Weighted Random Sampling WITH Replacement (via inverse transform sampling)
  • Weighted Random Sampling WITHOUT Replacement (via this method; a sketch follows below)
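To give a flavor of the last item, weighted sampling without replacement can be done with the Gumbel-top-k trick (a sketch of one such method, not necessarily the linked implementation):

import torch

def weighted_sample_without_replacement(weights, k):
    # Perturb the log-weights with Gumbel noise and keep the k largest keys;
    # this is equivalent to weighted sampling without replacement.
    gumbel = -torch.log(-torch.log(torch.rand_like(weights)))
    return torch.topk(weights.log() + gumbel, k).indices

idx = weighted_sample_without_replacement(torch.tensor([0.1, 0.1, 0.1, 0.7]), k=2)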
6 Likes

Update: There is currently a PR waiting for review in the PyTorch repo.

4 Likes

This is extremely slow and unacceptable.

1 Like

In my case, values.shape = (386363948, 2) and k = 190973, and the following code works fairly fast, taking about 0.1–0.2 seconds:

indices = random.sample(range(386363948), 190973)  # k unique indices, without replacement
indices = torch.tensor(indices)
sampled_values = values[indices]

Using torch.randperm, however, costs more than 20 seconds, presumably because it has to generate a full permutation of all 386,363,948 indices:

sampled_values = values[torch.randperm(386363948)[:190973]]
1 Like

Answer here!

1 Like

You can do it simply with the logic below:

samples = torch.tensor([-11, 5, 9])
rand_choices = samples[torch.randint(len(samples), (7,))]  # 7 choices, drawn with replacement
print(rand_choices)

Hope this helps!

1 Like

Is there a reason why you prefer:

rand_choices = samples[torch.randint(len(samples), (7,))]  # 7 choices

vs

perm = torch.randperm(tensor.size(0))

?

Just curious

Hi!
To clear things up:

If you want to do the equivalent of numpy.random.choice:

a = np.array([1, 2, 3, 4])
p = np.array([0.1, 0.1, 0.1, 0.7])
n = 2
replace = True

b = np.random.choice(a, p=p, size=n, replace=replace)

In PyTorch you can use torch.multinomial:

a = torch.tensor([1, 2, 3, 4])
p = torch.tensor([0.1, 0.1, 0.1, 0.7])
n = 2
replace = True

idx = p.multinomial(num_samples=n, replacement=replace)
b = a[idx]

Careful: np.random.choice defaults to replace=True, but torch.multinomial defaults to replacement=False.
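To make the differing defaults harder to trip over, you could wrap this in a small helper (torch_choice is a hypothetical name, not a PyTorch API):

import torch

def torch_choice(a, size, replace=True, p=None):
    # Hypothetical convenience wrapper mirroring np.random.choice's defaults.
    if p is None:
        p = torch.ones(a.numel()) / a.numel()  # uniform when no probabilities given
    idx = p.multinomial(num_samples=size, replacement=replace)
    return a[idx]

b = torch_choice(torch.tensor([1, 2, 3, 4]), size=2, p=torch.tensor([0.1, 0.1, 0.1, 0.7]))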

13 Likes

In case num_samples is not an int, how would one implement the above? Thank you!

If anyone is here looking for fast ways to select samples, I created a small comparison to time some of the popular random-indexing solutions from the forums. I wrote small functions that generate indices for selecting 10% of a population, and for a fair comparison each function returns the indices as a GPU tensor.

It looks like, if your population size is less than int32.MAX_VALUE, generating a random permutation on the GPU may be the fastest solution. Your mileage may vary, so I've included my entire plotting script below so you can test it.

import numpy as np
import torch
import random
import time
import pandas as pd
import matplotlib.pyplot as plt

def rand(pop_size, num_samples):
    """Use random.sample to generate indices."""
    # torch.tensor keeps the integer dtype, so the result is usable as an index
    return torch.tensor(random.sample(range(pop_size), num_samples)).to('cuda')

def mult_cpu(pop_size, num_samples):
    """Use torch.Tensor.multinomial to generate indices on a CPU tensor."""
    p = torch.ones(pop_size) / pop_size
    return p.multinomial(num_samples=num_samples, replacement=False).to('cuda')

def mult_gpu(pop_size, num_samples):
    """Use torch.Tensor.multinomial to generate indices on a GPU tensor."""
    p = torch.ones(pop_size, device='cuda') / pop_size
    return p.multinomial(num_samples=num_samples, replacement=False)

def perm_cpu(pop_size, num_samples):
    """Use torch.randperm to generate indices on a CPU tensor."""
    return torch.randperm(pop_size)[:num_samples].to('cuda')

def perm_gpu(pop_size, num_samples):
    """Use torch.randperm to generate indices on a GPU tensor."""
    return torch.randperm(pop_size, device='cuda')[:num_samples]

def perm_gpu_f32(pop_size, num_samples):
    """Use torch.randperm to generate indices on a 32-bit GPU tensor."""
    return torch.randperm(pop_size, dtype=torch.int32, device='cuda')[:num_samples]

def sort_rand_cpu(pop_size, num_samples):
    """Generate a random torch.Tensor (CPU) and sort it to generate indices."""
    return torch.argsort(torch.rand(pop_size))[:num_samples].to('cuda')

def sort_rand_gpu(pop_size, num_samples):
    """Generate a random torch.Tensor (GPU) and sort it to generate indices."""
    return torch.argsort(torch.rand(pop_size, device='cuda'))[:num_samples]

idx_fns = [rand, 
    mult_cpu, 
    mult_gpu, 
    perm_cpu, 
    perm_gpu,
    perm_gpu_f32,
    sort_rand_cpu,
    sort_rand_gpu]

pop_size = np.logspace(3, 7, 50, dtype=int)

d = []
for n_p in pop_size:
    print(f"Testing functions with {n_p} points.")
    for fn in idx_fns:
        tic = time.time()
        num_samples = int(0.1 * n_p)
        samples = fn(n_p, num_samples)
        toc = time.time()
        assert type(samples) == torch.Tensor

        d.append({
            'Population': n_p,
            'Samples': num_samples,
            'Function':fn.__name__,
            'Time': toc - tic
        })

df = pd.DataFrame(d)

fig, ax = plt.subplots()

for fn in idx_fns:
    idxs = df['Function'] == fn.__name__
    ax.plot(df['Population'][idxs], df['Time'][idxs], label=fn.__name__)

ax.set_title("Runtime of Index-Generating Methods for Randomly Selecting 10% of Population")
ax.set_xlabel('Population Size')
ax.set_ylabel('Time')
ax.set_yscale('log')
ax.legend(bbox_to_anchor=(1, 0.7))
plt.savefig('figs/randperm.png', bbox_inches='tight')
plt.show()
4 Likes

Thanks for sharing the code!
Since GPU operations are executed asynchronously, you would have to synchronize the code manually before starting and stopping the timer via torch.cuda.synchronize() to get the real execution time.
Otherwise you might be profiling the kernel launch times and blocking operations would accumulate the execution time of already running kernels.
torch.utils.benchmark provides a utility to run such comparisons and will add warmup iterations and the needed synchronizations for you.
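A minimal sketch of the manual-synchronization pattern described above (hand-timing a GPU op; the sizes are arbitrary):

import time
import torch

torch.cuda.synchronize()  # make sure previously launched kernels are done
tic = time.time()
idx = torch.randperm(1_000_000, device='cuda')[:100_000]
torch.cuda.synchronize()  # wait for randperm to actually finish
print(f'elapsed: {time.time() - tic:.6f}s')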

2 Likes

Ooh, thanks! I’ll have a look and see if I can update with proper benchmarks when I have a minute.

Here are the results with proper benchmarks! It looks like generating random permutations on the GPU is still the way to go if you want to generate indices for random selection. However, the GPU methods do not scale quite as well as it seemed before.

I couldn’t find a good way to access the benchmark results, so I settled for timeit(N).raw_times[0], which seems to give the median time spent.

import numpy as np
import torch
import random
import pandas as pd
import matplotlib.pyplot as plt
import torch.utils.benchmark as benchmark

def rand(pop_size, num_samples):
    """Use random.sample to generate indices."""
    # torch.tensor keeps the integer dtype, so the result is usable as an index
    return torch.tensor(random.sample(range(pop_size), num_samples)).to('cuda')

def mult_cpu(pop_size, num_samples):
    """Use torch.Tensor.multinomial to generate indices on a CPU tensor."""
    p = torch.ones(pop_size) / pop_size
    return p.multinomial(num_samples=num_samples, replacement=False).to('cuda')

def mult_gpu(pop_size, num_samples):
    """Use torch.Tensor.multinomial to generate indices on a GPU tensor."""
    p = torch.ones(pop_size, device='cuda') / pop_size
    return p.multinomial(num_samples=num_samples, replacement=False)

def perm_cpu(pop_size, num_samples):
    """Use torch.randperm to generate indices on a CPU tensor."""
    return torch.randperm(pop_size)[:num_samples].to('cuda')

def perm_gpu(pop_size, num_samples):
    """Use torch.randperm to generate indices on a GPU tensor."""
    return torch.randperm(pop_size, device='cuda')[:num_samples]

def perm_gpu_f32(pop_size, num_samples):
    """Use torch.randperm to generate indices on a 32-bit GPU tensor."""
    return torch.randperm(pop_size, dtype=torch.int32, device='cuda')[:num_samples]

def sort_rand_cpu(pop_size, num_samples):
    """Generate a random torch.Tensor (CPU) and sort it to generate indices."""
    return torch.argsort(torch.rand(pop_size))[:num_samples].to('cuda')

def sort_rand_gpu(pop_size, num_samples):
    """Generate a random torch.Tensor (GPU) and sort it to generate indices."""
    return torch.argsort(torch.rand(pop_size, device='cuda'))[:num_samples]

idx_fns = [rand, 
    mult_cpu, 
    mult_gpu, 
    perm_cpu, 
    perm_gpu,
    perm_gpu_f32,
    sort_rand_cpu,
    sort_rand_gpu]

pop_size = np.logspace(3, 7, 10, dtype=int)

d = []
for n_p in pop_size:
    print(f"Testing functions with {n_p} points.")
    for fn in idx_fns:
        num_samples = int(0.1 * n_p)

        t0 = benchmark.Timer(
            stmt=f"{fn.__name__}(n_p, num_samples)",
            setup=f"from __main__ import {fn.__name__}",
            globals={'n_p': n_p, 'num_samples': num_samples}
        )

        d.append({
            'Population': n_p,
            'Samples': num_samples,
            'Function':fn.__name__,
            'Time': t0.timeit(11).raw_times[0]
        })

df = pd.DataFrame(d)

fig, ax = plt.subplots()

for fn in idx_fns:
    idxs = df['Function'] == fn.__name__
    ax.plot(df['Population'][idxs], df['Time'][idxs], label=fn.__name__)

ax.set_title("Runtime of Index-Generating Methods for Randomly Selecting 10% of Population")
ax.set_xlabel('Population Size')
ax.set_ylabel('Time')
ax.set_yscale('log')
ax.legend(bbox_to_anchor=(1, 0.7))
plt.savefig('figs/randperm.png', bbox_inches='tight')
plt.show()
8 Likes

I had the same problem and came up with an additional implementation of my own, and it seems to work fairly well. I ran it against the code above (my exact code is below).

Do note that this is only really useful if you don't need the sampled indices in a random order, just a random subset: torch.unique currently sorts its output automatically in the CUDA case, so the returned indices come back sorted. If you do want a random shuffle, it is more or less the same speed as randperm.
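As an aside, if a random order is needed, one could shuffle just the k selected indices afterward (a hypothetical post-processing step, not part of the benchmark below), which is cheaper than permuting the whole population:

import torch

def shuffle_selected(vec):
    # Permute only the k returned indices; a far smaller permutation than a
    # full torch.randperm over the entire population.
    return vec[torch.randperm(vec.numel(), device=vec.device)]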

import numpy as np
import torch
import random
import pandas as pd
import matplotlib.pyplot as plt
import torch.utils.benchmark as benchmark


def rand(pop_size, num_samples):
    """Use random.sample to generate indices."""
    # torch.tensor keeps the integer dtype, so the result is usable as an index
    return torch.tensor(random.sample(range(pop_size), num_samples)).to("cuda")


def mult_cpu(pop_size, num_samples):
    """Use torch.Tensor.multinomial to generate indices on a CPU tensor."""
    p = torch.ones(pop_size) / pop_size
    return p.multinomial(num_samples=num_samples, replacement=False).to("cuda")


def mult_gpu(pop_size, num_samples):
    """Use torch.Tensor.multinomial to generate indices on a GPU tensor."""
    p = torch.ones(pop_size, device="cuda") / pop_size
    return p.multinomial(num_samples=num_samples, replacement=False)


def perm_cpu(pop_size, num_samples):
    """Use torch.randperm to generate indices on a CPU tensor."""
    return torch.randperm(pop_size)[:num_samples].to("cuda")


def perm_gpu(pop_size, num_samples):
    """Use torch.randperm to generate indices on a GPU tensor."""
    return torch.randperm(pop_size, device="cuda")[:num_samples]


def perm_gpu_f32(pop_size, num_samples):
    """Use torch.randperm to generate indices on a 32-bit GPU tensor."""
    return torch.randperm(pop_size, dtype=torch.int32, device="cuda")[:num_samples]


def sort_rand_cpu(pop_size, num_samples):
    """Generate a random torch.Tensor (CPU) and sort it to generate indices."""
    return torch.argsort(torch.rand(pop_size))[:num_samples].to("cuda")


def sort_rand_gpu(pop_size, num_samples):
    """Generate a random torch.Tensor (GPU) and sort it to generate indices."""
    return torch.argsort(torch.rand(pop_size, device="cuda"))[:num_samples]


def torchrand(pop_size, num_samples, device: str):
    vec = torch.unique(
        (torch.rand(num_samples, device=device) * pop_size).floor().long()
    )
    # Eliminate all duplicate entries. Might slow down the procedure but totally worth it.
    while vec.shape[0] != num_samples:
        vec = torch.unique(
            torch.cat(
                [
                    vec,
                    (torch.rand(num_samples - vec.shape[0], device=device) * pop_size)
                    .floor()
                    .long(),
                ]
            )
        )

    return vec.view(-1)


def torchrand_cpu(pop_size, num_samples, shuffle=False):
    return (
        torchrand(pop_size, num_samples, device="cpu")
        if shuffle is False
        else torch.randperm(pop_size)
    )  # [:num_samples]


def torchrand_gpu(pop_size, num_samples, shuffle=False):
    return (
        torchrand(pop_size, num_samples, device="cuda")
        if shuffle is False
        else torch.randperm(pop_size, device="cuda")
    )  # [:num_samples]


idx_fns = [
    rand,
    mult_cpu,
    mult_gpu,
    perm_cpu,
    perm_gpu,
    perm_gpu_f32,
    sort_rand_cpu,
    sort_rand_gpu,
    torchrand_cpu,
    torchrand_gpu,
]

pop_size = np.logspace(3, 7, 10, dtype=int)
d = []
for n_p in pop_size:
    print(f"Testing functions with {n_p} points.")
    for fn in idx_fns:
        num_samples = int(0.1 * n_p)

        t0 = benchmark.Timer(
            stmt=f"{fn.__name__}(n_p, num_samples)",
            setup=f"from __main__ import {fn.__name__}",
            globals={"n_p": n_p, "num_samples": num_samples},
        )

        d.append(
            {
                "Population": n_p,
                "Samples": num_samples,
                "Function": fn.__name__,
                "Time": t0.timeit(11).raw_times[0],
            }
        )

df = pd.DataFrame(d)

fig, ax = plt.subplots()

for fn in idx_fns:
    idxs = df["Function"] == fn.__name__
    ax.plot(df["Population"][idxs], df["Time"][idxs], label=fn.__name__)

ax.set_title(
    "Runtime of Index-Generating Methods for Randomly Selecting 10% of Population"
)
ax.set_xlabel("Population Size")
ax.set_ylabel("Time")
ax.set_yscale("log")
ax.legend(bbox_to_anchor=(1, 0.7))
plt.savefig("figs/randperm.png", bbox_inches="tight")
plt.show()


1 Like

Hi,
Both randperm and multinomial cannot solve my problem, because they cannot generate multi-dimensional tensors.

In numpy:

sample_set = [0.4, 0.2, 0.1]
np.random.choice(sample_set, size=(3, 3, 3, 3), p=[0.33, 0.33, 0.34])

Is there any equivalent method in torch?

You can just sample into a 1D tensor and then reshape it to your desired size.
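For example, with torch.multinomial (a minimal sketch of that reshape trick):

import torch

sample_set = torch.tensor([0.4, 0.2, 0.1])
p = torch.tensor([0.33, 0.33, 0.34])

# Draw all 3*3*3*3 indices in one flat call (with replacement, since we need
# more samples than there are elements), then reshape to the target shape.
idx = p.multinomial(num_samples=3 * 3 * 3 * 3, replacement=True)
samples = sample_set[idx].reshape(3, 3, 3, 3)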

1 Like