Slow For Loop On Small Inputs?


I’m trying to find a way to avoid a painfully slow for-loop in PyTorch. Basically, I have a tensor, and I want to split it up into pieces and feed those pieces into my model, similar in spirit to a grouped convolution.

        self.C = C
        self.block = Block(C, 3, 64)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1, end_dim=-1).unsqueeze(0)
        x = torch.split(x, self.C, -1)  # tuple of chunks of size C along the last dim
        attention = []
        for i in x:
            attended = self.block(i)
            attention.append(attended)  # this append was missing; stack() needs a non-empty list
        attention = torch.stack(attention, 1)
        return attention

Small values of C on a large tensor make this operation surprisingly slow, because of the Python for-loop the above code runs through. Is there any way to fix this issue other than the slightly hacky workaround of exchanging the batch dimension for the num_block dimension and then running it through the model?
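For reference, here is a sketch of the batch-folding workaround mentioned above, assuming the chunks are independent of each other. `Block` is a hypothetical stand-in here (its definition isn't shown in the post), and the input sizes are made up for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the Block module from the post:
# maps (N, C) -> (N, 64).
class Block(nn.Module):
    def __init__(self, C, hidden, out):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(C, hidden), nn.ReLU(), nn.Linear(hidden, out)
        )

    def forward(self, x):
        return self.net(x)

C = 4
block = Block(C, 3, 64)

x = torch.randn(8, 16)                          # batch of 8, feature dim 16
num_chunks = x.shape[1] // C                    # 4 chunks of size C

# Fold the chunks into the batch dimension and run a single forward pass
# instead of a Python loop over the chunks.
folded = x.reshape(-1, C)                       # (8 * 4, C)
out = block(folded)                             # (8 * 4, 64)
out = out.reshape(x.shape[0], num_chunks, -1)   # (8, 4, 64)
```

Because `reshape` splits each row into contiguous chunks of size C in order, `out[b, k]` matches `block` applied to the k-th chunk of batch element b, so the result agrees with the looped version while doing only one forward pass.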