Splitting a batch into equal chunks without torch.split

Is there a way to split a batch of, say, N x K (where N is the number of samples and K the feature dimensionality) into equal chunks? torch.split() does the job, but it unfortunately returns a tuple of tensors, which I don't think you can backprop through. Tensor.view() could also work, but I'm not sure about the order of traversal, as I need to make sure that each chunk in the new split tensor corresponds to the same order in the unsplit tensor, i.e. if I had the tensor [1, 2, 3, 4, 5, 6] and split it in chunks of two, I would like to preserve [[1, 2], [3, 4], [5, 6]].
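
For concreteness, here's a toy sketch of what I mean (the chunk size of 2 is just for illustration):

import torch

x = torch.tensor([1, 2, 3, 4, 5, 6])

# torch.split gives exactly the chunking I want, but as a tuple of tensors
print(torch.split(x, 2))
# (tensor([1, 2]), tensor([3, 4]), tensor([5, 6]))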


Could you post the shapes of your data?
Your example is quite simple, as you already said, so let's have a look at which method will be the most suitable. :wink:

Hey, so here's my actual use case. I'm using a pretrained LSTM-based sentence encoder (InferSent) that consumes a group of sentences, i.e. cliques. A batch looks like [L, N, D], where L is the sentence length, N is the number of sentences, and D is the embedding dimension of a word in the sentence. This batch is fed into the LSTM encoder, which outputs [N, D2], where D2 is the sentence embedding dimension.

Now the exact challenge is that each row is one sentence, but I want to work with groups of consecutive sentences (cliques). That is to say, for a clique size of 3, rows [0:2] should be grouped together with the ordering preserved. So I want to somehow reshape this batch into [N / 3, 3, D2], feed it into some other downstream model, and still be end-to-end differentiable. torch.split() returns a tuple, which would break the computation graph. I'm currently using .view() to get around this problem, but I'm not sure if that's the best practice here.
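
Roughly what I'm doing at the moment (a minimal sketch; the shapes and the names like clique_size and sentence_emb are just made up for illustration):

import torch

N, D2 = 9, 10            # N must be divisible by the clique size
clique_size = 3

# stand-in for the [N, D2] output of the sentence encoder
sentence_emb = torch.randn(N, D2, requires_grad=True)

# group every 3 consecutive sentences into one clique: [N / 3, 3, D2]
cliques = sentence_emb.view(N // clique_size, clique_size, D2)
print(cliques.shape)              # torch.Size([3, 3, 10])

# .view returns a view on the same storage, so gradients still flow
cliques.sum().backward()
print(sentence_emb.grad.shape)    # torch.Size([9, 10])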

Seems like a valid approach to me. Currently I don't know of another method for your use case.
Could you check if this order is right:

import torch

N = 9
D2 = 10

# each row is filled with its row index, so the order is easy to track
x = torch.arange(N).unsqueeze(1).repeat(1, D2)
print(x)
print(x.view(N // 3, 3, D2))

As you can see, the D2 dimension is still the same.
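
If you want to double-check that these chunks match what torch.split would give you, continuing from the snippet above, a quick sanity check could be (torch.stack is only used for the comparison, you wouldn't need it in your model):

chunks = torch.stack(torch.split(x, 3, dim=0))
print(torch.equal(chunks, x.view(N // 3, 3, D2)))   # prints True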