I’m trying to find a way to avoid a painfully slow for-loop in PyTorch. Basically, I have a tensor, and I want to split it up into pieces and feed those pieces into my model, similar in spirit to a grouped convolution.
```python
self.C = C
self.block = Block(C, 3, 64)

def forward(self, x):
    x_shape = x.shape
    x = torch.flatten(x, start_dim=1, end_dim=-1).unsqueeze(0)
    x = torch.split(x, self.C, -1)
    attention = []
    for i in x:
        attended = self.block(i)
        attention.append(attended)
    attention = torch.stack(attention, 1)
```
Small values of C combined with a large tensor make this operation surprisingly slow, because of the Python for-loop the code above runs through. Is there any way to fix this issue other than the slightly hacky one of exchanging the batch dimension for the num_block dimension and then running it through the model?
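For reference, the batch-dimension swap mentioned above can be sketched as follows. This is only a minimal illustration, assuming a hypothetical `Block` that acts on the last dimension of each chunk independently (a small MLP here, standing in for the real module): the chunk index is folded into the batch dimension, the block runs once over all chunks, and the result is unfolded back.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the Block(C, 3, 64) module in the question:
# maps a (..., C) chunk to an output of the same shape.
class Block(nn.Module):
    def __init__(self, C, k, hidden):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(C, hidden), nn.ReLU(), nn.Linear(hidden, C))

    def forward(self, x):
        return self.net(x)

C = 8
block = Block(C, 3, 64)

x = torch.randn(4, 2, 16)              # (batch, ...) with numel per sample divisible by C
flat = torch.flatten(x, start_dim=1)   # (batch, N)

# Looped version: split into chunks of size C, run the block per chunk.
chunks = torch.split(flat, C, dim=-1)
looped = torch.stack([block(c) for c in chunks], dim=1)   # (batch, num_blocks, C)

# Vectorized version: fold the chunk index into the batch dimension,
# run the block once over all chunks, then unfold back.
num_blocks = flat.shape[-1] // C
batched = flat.reshape(-1, num_blocks, C).reshape(-1, C)  # (batch * num_blocks, C)
out = block(batched).reshape(flat.shape[0], num_blocks, C)
```

Both versions produce the same values (up to floating-point tolerance), but the second issues a single forward pass instead of `num_blocks` Python-level calls.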