I’m trying to find a way to avoid a painfully slow for-loop in PyTorch. Basically, I have a tensor, and I want to split it up into pieces and feed those pieces into my model, similar in spirit to a grouped convolution.
```python
self.C = C
self.block = Block(C, 3, 64)

def forward(self, x):
    x_shape = x.shape
    x = torch.flatten(x, start_dim=1, end_dim=-1).unsqueeze(0)
    x = torch.split(x, self.C, -1)
    attention = []
    for i in x:
        attended = self.block(i)
        attention.append(attended)
    attention = torch.stack(attention, 1)
```
Small values of C combined with a large tensor make this operation surprisingly slow, because of the Python for-loop the code above runs through. Is there any way to fix this issue other than the slightly hacky one of exchanging the batch dimension for the num_block dimension and then running it through the model?
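For reference, the batch-dimension swap mentioned above can be sketched as follows. This is only a minimal illustration, assuming a hypothetical `Block` that acts on the last dimension of each chunk independently (a small MLP here, standing in for the real module): the chunk index is folded into the batch dimension, the block runs once over all chunks, and the result is unfolded back.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the Block(C, 3, 64) module in the question:
# maps a (..., C) chunk to an output of the same shape.
class Block(nn.Module):
    def __init__(self, C, k, hidden):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(C, hidden), nn.ReLU(), nn.Linear(hidden, C))

    def forward(self, x):
        return self.net(x)

C = 8
block = Block(C, 3, 64)

x = torch.randn(4, 2, 16)              # (batch, ...) with numel per sample divisible by C
flat = torch.flatten(x, start_dim=1)   # (batch, N)

# Looped version: split into chunks of size C, run the block per chunk.
chunks = torch.split(flat, C, dim=-1)
looped = torch.stack([block(c) for c in chunks], dim=1)   # (batch, num_blocks, C)

# Vectorized version: fold the chunk index into the batch dimension,
# run the block once over all chunks, then unfold back.
num_blocks = flat.shape[-1] // C
batched = flat.reshape(-1, num_blocks, C).reshape(-1, C)  # (batch * num_blocks, C)
out = block(batched).reshape(flat.shape[0], num_blocks, C)
```

Both versions produce the same values (up to floating-point tolerance), but the second issues a single forward pass instead of `num_blocks` Python-level calls.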