Torch.split in batches

Hello,

Does anyone know how to split a batch of tensors at a different position for each sample?

I am trying to separate each sample into two parts according to split_positions, which holds a different split index for each sample:

batch.shape = torch.Size([128, 68, 1])      # (batch_size, max_len, 1)
split_positions.shape = torch.Size([128])   # split position of each sample

Thanks,

So as a result you would expect a list of 128 tuples, each with two tensors inside?

I would like to have 2 tensors of the same size ([128, 68, 1]). One should contain the left part of each sample, followed by zeros up to max_len (which is 68). The other should contain the right part of each sample, also followed by zeros up to max_len.

This is an example for a single, small sample:

import torch
import torch.nn.functional as F

max_len = 10
split = 3
left_weights = torch.tensor([10, 3, 1, 20, 4, 0, 0, 0, 0, 0])
# separate at the split position
l, m_l = left_weights[0:split], left_weights[split:]

# pad on the right with zeros up to max_len
l = F.pad(l, (0, max_len - len(l)))
m_l = F.pad(m_l, (0, max_len - len(m_l)))

The results should look like this:

l   = tensor([10, 3, 1, 0, 0, 0, 0, 0, 0, 0])
m_l = tensor([20, 4, 0, 0, 0, 0, 0, 0, 0, 0])

Ok, thanks for the info.
I don’t know an elegant way to achieve exactly these results.
If your results could instead look like this:

l   = tensor([10, 3, 1, 0, 0, 0, 0, 0, 0, 0])
m_l = tensor([0, 0, 0, 20, 4, 0, 0, 0, 0, 0])

i.e. with the zeros in front of the right part so that it stays at its original positions, you could use .scatter_().

In case it helps, here is the code to achieve this, although unfortunately it isn’t exactly what you wanted:

import torch

# Init data tensor and splits
x = torch.randn(20, 10)
max_len = x.size(1)
splits = torch.empty(20, dtype=torch.long).random_(1, 10)

# Calculate size for result tensor
new_rows = x.size(0) * 2
offset = x.size(0)
z = torch.zeros(new_rows, max_len)

# Calculate scatter indices (one target row index per element of x)
# The "left" part goes into z[:offset], the "right" part goes into z[offset:]
split_scatter = [torch.cat((torch.ones(1, split, dtype=torch.long) * idx,
                            torch.ones(1, max_len - split, dtype=torch.long) * (idx + offset)),
                           dim=1) for idx, split in enumerate(splits.tolist())]
split_scatter = torch.cat(split_scatter)

# Apply scatter along dim 0: z[split_scatter[i][j]][j] = x[i][j]
z.scatter_(0, split_scatter, x)

# Cut tensors
x1 = z[:offset]
x2 = z[offset:]
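
As a side note, the same layout (left part in place, right part in place, zeros elsewhere) can probably also be obtained with a boolean mask instead of scatter indices. This is only a sketch under the same assumptions as the snippet above (a 2D x and one split index per row); the names cols and left_mask are made up for illustration:

```python
import torch

x = torch.randn(20, 10)
splits = torch.randint(1, 10, (20,))          # one split index per row, in [1, 9]

# Column indices 0..max_len-1, broadcast against the per-row split index
cols = torch.arange(x.size(1)).unsqueeze(0)   # shape (1, max_len)
left_mask = cols < splits.unsqueeze(1)        # shape (20, max_len), True left of the split

x1 = x * left_mask     # left part kept, everything from the split onwards zeroed
x2 = x * ~left_mask    # right part kept at its original positions, left part zeroed
```

Since the two masks are complementary, x1 + x2 gives back x.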

Yes, this also works for me, thank you very much!
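
For the layout asked for originally, with the right part moved to the front and the zeros after it, a gather-based variant might work without a Python loop. This is only a sketch: the function name split_and_pad is hypothetical, and it assumes batch has shape (batch_size, max_len, 1) and split_positions is a long tensor of shape (batch_size,):

```python
import torch

def split_and_pad(batch, split_positions):
    """Hypothetical helper: split each sample at its own position, zero-pad both parts."""
    batch_size, max_len, _ = batch.shape
    pos = torch.arange(max_len, device=batch.device).unsqueeze(0)  # (1, max_len)
    splits = split_positions.unsqueeze(1)                          # (batch_size, 1)

    # Left part: keep positions before the split, zero everything else
    left = batch * (pos < splits).unsqueeze(-1)

    # Right part: output position j reads from input position j + split
    src = pos + splits                                             # (batch_size, max_len)
    valid = (src < max_len).unsqueeze(-1)                          # positions past the end become zeros
    src = src.clamp(max=max_len - 1).unsqueeze(-1)                 # keep gather indices in range
    right = torch.gather(batch, 1, src) * valid
    return left, right

# The single-sample example from above, as a batch of one
batch = torch.tensor([10, 3, 1, 20, 4, 0, 0, 0, 0, 0]).view(1, 10, 1)
split_positions = torch.tensor([3])
left, right = split_and_pad(batch, split_positions)
```

Here left[0, :, 0] should match l and right[0, :, 0] should match m_l from the single-sample example.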

@ptrblck If I have a tensor of shape [200, 5, 76] (frames, batch size, vocabulary size), how can I use split to get a list of 5 two-dimensional tensors, one per sample in the batch?
I want to evaluate them individually, because I need to calculate the edit distance for each of them.
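
One possible way to do this, sketched here with random data standing in for the real model output (the variable names are made up for illustration), is torch.unbind along the batch dimension. torch.split with a chunk size of 1 also works, but it keeps the batch dimension, so each piece needs a squeeze:

```python
import torch

frames, batch_size, vocab = 200, 5, 76
output = torch.randn(frames, batch_size, vocab)   # placeholder for the real tensor

# unbind removes dim 1 and returns a tuple of 5 tensors of shape (200, 76)
per_sample = torch.unbind(output, dim=1)

# Equivalent with split: chunks of size 1 along dim 1, then drop that dimension
per_sample_split = [chunk.squeeze(1) for chunk in torch.split(output, 1, dim=1)]
```

Each element of per_sample can then be passed to the edit distance computation on its own.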