Equivalent of `tf.dynamic_partition`

I need an equivalent of tf.dynamic_partition in PyTorch. Is there anything with similar functionality in PyTorch or another library, or is there a simple and clever way to code it for PyTorch that still runs fast? The same goes for tf.dynamic_stitch. Thanks!

The following implementation works as an equivalent of the "Vector partitions" example in the tf documentation.

import torch

def dynamic_partition(data, partitions, num_partitions):
  # Collect, for each partition id i, the entries of `data` whose id equals i.
  res = []
  for i in range(num_partitions):
    res.append(data[(partitions == i).nonzero().squeeze(1)])
  return res

data = torch.tensor([10., 20., 30., 40., 50.])
partitions = torch.tensor([0, 0, 1, 1, 0])
dynamic_partition(data, partitions, 2)
# [tensor([10., 20., 50.]), tensor([30., 40.])]

There are ways of implementing this more efficiently, especially for big CUDA tensors. But maybe this is good enough for your application.
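For larger inputs, one possible alternative (a sketch, assuming `partitions` is a 1-D integer tensor on the same device as `data`) is to sort by partition id once and split, instead of scanning the data once per partition:

def dynamic_partition_sorted(data, partitions, num_partitions):
  # A stable sort groups each partition into a contiguous block;
  # a single split by per-partition counts then yields the pieces.
  order = torch.argsort(partitions, stable=True)
  counts = torch.bincount(partitions, minlength=num_partitions)
  return list(torch.split(data[order], counts.tolist()))

dynamic_partition_sorted(data, partitions, 2)
# [tensor([10., 20., 50.]), tensor([30., 40.])]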

Regarding tf.dynamic_stitch, the following snippet works as well (it matches the input -> output example from the tf documentation).

# Example inputs from the tf.dynamic_stitch documentation.
indices = [None] * 3
indices[0] = 6
indices[1] = [4, 1]
indices[2] = [[5, 2], [0, 3]]

data = [None] * 3
data[0] = [61, 62]
data[1] = [[41, 42], [11, 12]]
data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]

data = [torch.tensor(d) for d in data]
indices = [torch.tensor(idx) for idx in indices]

def dynamic_stitch(indices, data):
  # The output has one slot per index across all of `indices`.
  n = sum(idx.numel() for idx in indices)
  res = [None] * n
  for i, data_ in enumerate(data):
    # Flatten the indices and reshape the data so that row k of `d`
    # is the slice that goes to output position idx[k].
    idx = indices[i].view(-1)
    d = data_.view(idx.numel(), -1)
    for k, idx_ in enumerate(idx):
      res[idx_] = d[k]
  return res

dynamic_stitch(indices, data)
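Note that tf.dynamic_stitch returns a single merged tensor rather than a list. If that is what you need, the rows returned above can be stacked (a small sketch, assuming every data slice has the same trailing size as in the documentation example; the helper flattens trailing dimensions, so each row is 1-D):

# Stack the per-index rows into one tensor, mirroring tf.dynamic_stitch's output.
merged = torch.stack(dynamic_stitch(indices, data))
# Rows, matching the tf documentation example:
# [[ 1,  2], [11, 12], [21, 22], [31, 32], [41, 42], [51, 52], [61, 62]]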

If these implementations don’t perform well enough for your application, consider implementing an extension.
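Before going that far, a vectorized variant may already be fast enough. Here is a minimal sketch that writes all slices with a single indexed assignment; it assumes the indices are integer tensors that together cover 0..n-1 exactly once, with no duplicates, as in the documentation example:

def dynamic_stitch_vectorized(indices, data):
  # Flatten every indices/data pair into one 1-D index vector and a matching
  # 2-D block of rows, then place all rows with a single indexed write.
  idx = torch.cat([i.reshape(-1) for i in indices])
  rows = torch.cat([d.reshape(i.numel(), -1) for i, d in zip(indices, data)])
  out = torch.empty_like(rows)
  out[idx] = rows
  return out

dynamic_stitch_vectorized(indices, data)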


Great! Thanks for the answer!