I need an equivalent of `tf.dynamic_partition` in PyTorch. Is there anything with similar functionality in PyTorch or another library, or is there a simple and clever way to code it in PyTorch so that it runs fast? The same goes for `tf.dynamic_stitch`. Thanks!

The following implementation is equivalent to the "Vector partitions" example in the TF documentation.

```python
import torch

def dynamic_partition(data, partitions, num_partitions):
    # For each partition id i, gather the elements of `data` whose entry in `partitions` equals i
    res = []
    for i in range(num_partitions):
        res += [data[(partitions == i).nonzero().squeeze(1)]]
    return res

data = torch.tensor([10, 20, 30, 40, 50])
partitions = torch.tensor([0, 0, 1, 1, 0])
print(dynamic_partition(data, partitions, 2))
# [tensor([10, 20, 50]), tensor([30, 40])]
```

There are more efficient ways to implement this, especially for large CUDA tensors, but this may be good enough for your application.
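One way to avoid building one boolean mask per partition is to sort the elements by partition id once and then split the sorted tensor by per-partition counts. This is a minimal sketch, not a tuned implementation: the name `dynamic_partition_sorted` is hypothetical, and it assumes PyTorch 1.9 or newer (for `torch.sort(..., stable=True)`) and that every partition id lies in `[0, num_partitions)`. Like `tf.dynamic_partition`, it expects `partitions.shape` to be a prefix of `data.shape`.

```python
import torch

def dynamic_partition_sorted(data, partitions, num_partitions):
    # Hypothetical loop-free variant: one stable sort plus one split,
    # instead of one boolean mask per partition.
    flat_parts = partitions.reshape(-1).long()
    # Collapse the leading dims covered by `partitions` into a single row dim.
    flat_data = data.reshape(flat_parts.numel(), *data.shape[partitions.dim():])
    # A stable sort keeps the original order of the elements inside each partition.
    _, order = torch.sort(flat_parts, stable=True)
    # Assumes ids are in [0, num_partitions); .tolist() causes one host sync on CUDA.
    counts = torch.bincount(flat_parts, minlength=num_partitions).tolist()
    return list(torch.split(flat_data[order], counts))
```

For example, `dynamic_partition_sorted(torch.tensor([10, 20, 30, 40, 50]), torch.tensor([0, 0, 1, 1, 0]), 2)` returns the same two parts as the loop version. Whether it is actually faster depends on tensor size and device, so it is worth benchmarking both.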

For `tf.dynamic_stitch`, the following snippet also works (it reproduces the input/output example from the TF documentation).

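```python
import torch

def dynamic_stitch(indices, data):
    # The output has max(index) + 1 rows; later (indices, data) pairs overwrite earlier ones
    n = max(int(idx.max()) for idx in indices) + 1
    res = [None] * n
    for idx, d in zip(indices, data):
        # Flatten the index dims of `d` so that row k pairs with the k-th flattened index
        rows = d.reshape(idx.numel(), *d.shape[idx.dim():])
        for k, j in enumerate(idx.reshape(-1).tolist()):
            res[j] = rows[k]
    return torch.stack(res)

# Input/output example from the tf.dynamic_stitch documentation
indices = [torch.tensor(6),
           torch.tensor([4, 1]),
           torch.tensor([[5, 2], [0, 3]])]
data = [torch.tensor([61, 62]),
        torch.tensor([[41, 42], [11, 12]]),
        torch.tensor([[[51, 52], [21, 22]], [[1, 2], [31, 32]]])]
print(dynamic_stitch(indices, data).tolist())
# [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42], [51, 52], [61, 62]]
```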

If these implementations don’t perform well enough for your application, consider implementing a custom C++/CUDA extension.
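Before writing an extension, it may be worth trying a version of `tf.dynamic_stitch` with no Python loop over elements. In this sketch (the name `dynamic_stitch_vectorized` is hypothetical), all indices and data are concatenated, `Tensor.scatter_reduce` with `reduce="amax"` (PyTorch 1.12 or newer assumed) finds for each output row the last input row that writes to it, so later pairs win as in TF, and those rows are gathered in one indexing step. Output rows that no index refers to are filled with zeros; that is a choice made for this sketch, not documented TF behavior.

```python
import torch

def dynamic_stitch_vectorized(indices, data):
    # Hypothetical loop-free sketch; assumes PyTorch >= 1.12 for Tensor.scatter_reduce.
    flat_idx = torch.cat([i.reshape(-1).long() for i in indices])
    # Row k of flat_data is the slice that should land at output row flat_idx[k].
    flat_data = torch.cat([d.reshape(i.numel(), *d.shape[i.dim():])
                           for i, d in zip(indices, data)])
    n = int(flat_idx.max()) + 1
    # For every output row, the position of the LAST input row targeting it (-1 if none),
    # so that later (indices, data) pairs overwrite earlier ones.
    positions = torch.arange(flat_idx.numel(), device=flat_idx.device)
    last = torch.full((n,), -1, dtype=torch.long, device=flat_idx.device)
    last = last.scatter_reduce(0, flat_idx, positions, reduce="amax")
    # Rows that no index refers to stay zero (an assumption of this sketch).
    out = flat_data.new_zeros((n, *flat_data.shape[1:]))
    hit = last >= 0
    out[hit] = flat_data[last[hit]]
    return out
```

Called with the `indices` and `data` lists from the TF documentation example, it should produce the same merged rows as the loop version, as a single tensor.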


Great! Thanks for the answer!