im2col / im2col_backward op lowering not supported in XLA: any temporary workaround to replace `nn.Unfold` without triggering the unsupported ops?

I am trying to speed up my PyTorch code with PyTorch Lightning, which supports TPU training. However, training runs very slowly: each iteration takes about 20 s on the TPU, whereas it takes only about 0.5 s on a GPU (roughly 40x slower).

Profiling shows the following:

pt-xla-profiler: TransferFromServerTime too frequent: 449 counts during 6 steps
pt-xla-profiler: Op(s) not lowered: aten::im2col, aten::im2col_backward,  Please open a GitHub issue with the above op lowering requests.

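So XLA has no lowering for `aten::im2col` and `aten::im2col_backward`. These ops fall back to CPU execution, which would also explain the frequent host transfers reported above.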

Looking through my PyTorch code, it is `nn.Unfold` that triggers `im2col` and `im2col_backward`.

I can’t afford to rent multi-GPU instances and really need the TPU to speed things up.

Since XLA is unlikely to support `im2col` and `im2col_backward` any time soon, is there a way to replicate what `nn.Unfold` does without triggering these two ops?
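One possible workaround, sketched here and not tested on a TPU: express unfold as a grouped `F.conv2d` whose filters are one-hot "identity" kernels. Each output channel copies one kernel position of one input channel, so the result has the same `(B, C*k*k, L)` layout as `F.unfold`, and the backward pass becomes an ordinary convolution gradient instead of `im2col_backward`. This assumes XLA lowers `conv2d` (convolution is a core XLA op); `unfold_via_conv` is a hypothetical helper name, and it only handles square kernels with integer `stride`, `padding` and `dilation`.

```python
import torch
import torch.nn.functional as F


def unfold_via_conv(x: torch.Tensor, kernel_size: int, stride: int = 1,
                    padding: int = 0, dilation: int = 1) -> torch.Tensor:
    """Hypothetical nn.Unfold replacement built from a grouped conv2d.

    Returns a (B, C * k * k, L) tensor laid out like F.unfold's output.
    """
    B, C, H, W = x.shape
    k = kernel_size
    # One-hot filters: filter i has a single 1 at kernel position (i // k, i % k).
    eye = torch.eye(k * k, dtype=x.dtype, device=x.device).reshape(k * k, 1, k, k)
    # Repeat per input channel. With groups=C, output channels c*k*k .. (c+1)*k*k - 1
    # all read input channel c, giving the same channel-major order as F.unfold.
    weight = eye.repeat(C, 1, 1, 1)  # (C*k*k, 1, k, k), constant, no grad needed
    out = F.conv2d(x, weight, stride=stride, padding=padding,
                   dilation=dilation, groups=C)  # (B, C*k*k, H_out, W_out)
    # Flatten the spatial grid row-major, matching F.unfold's L axis.
    return out.reshape(B, C * k * k, -1)


# Usage: non-overlapping 4x4 patches of a 48-channel 56x56 feature map.
feat = torch.randn(2, 48, 56, 56)
patches = unfold_via_conv(feat, kernel_size=4, stride=4)  # shape (2, 768, 196)
```

The trade-off is extra arithmetic: every output value is a k*k-tap dot product with mostly zero weights, so for large kernels this may be noticeably more expensive than a pure gather. When the kernel size equals the stride, a plain `reshape`/`permute` avoids that cost.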

Hi Cloud_Huang, could you provide a sample of the code that is using unfold and fold?

Here you go

```python
import math

import torch.nn as nn


class PixelEmbed(nn.Module):
    """ Image to Pixel Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
        super().__init__()  # required before assigning submodules
        num_patches = (img_size // patch_size) ** 2
        self.img_size = img_size
        self.num_patches = num_patches
        self.in_dim = in_dim
        new_patch_size = math.ceil(patch_size / stride)
        self.new_patch_size = new_patch_size

        self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
        # nn.Unfold is the module that dispatches to aten::im2col / aten::im2col_backward
        self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)

    def forward(self, x, pixel_pos):
        B, C, H, W = x.shape
        assert H == self.img_size and W == self.img_size, \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
        x = self.proj(x)    # (B, in_dim, H', W')
        x = self.unfold(x)  # (B, in_dim * p * p, num_patches), p = new_patch_size
        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size)
        x = x + pixel_pos
        x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
        return x
```
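In this module the unfold uses `kernel_size == stride`, so the patches do not overlap and can be extracted with `reshape` + `permute` alone, which should not dispatch to `im2col` at all. A sketch under that assumption (`PixelEmbedNoUnfold` is a hypothetical name, intended to give the same output as the `unfold` → `transpose` → `reshape` path above; not profiled on a TPU):

```python
import math

import torch
import torch.nn as nn


class PixelEmbedNoUnfold(nn.Module):
    """Hypothetical PixelEmbed variant without nn.Unfold (no aten::im2col).

    Only valid because the unfold kernel size equals its stride, so the
    patches tile the feature map without overlap.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
        super().__init__()
        self.img_size = img_size
        self.num_patches = (img_size // patch_size) ** 2
        self.in_dim = in_dim
        self.new_patch_size = math.ceil(patch_size / stride)
        self.proj = nn.Conv2d(in_chans, in_dim, kernel_size=7, padding=3, stride=stride)

    def forward(self, x, pixel_pos):
        B, C, H, W = x.shape
        assert H == self.img_size and W == self.img_size, \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
        x = self.proj(x)                      # (B, D, Hp, Wp)
        k = self.new_patch_size
        _, D, Hp, Wp = x.shape
        nh, nw = Hp // k, Wp // k             # patches per side
        x = x[:, :, :nh * k, :nw * k]         # drop any remainder, as unfold does
        # (B, D, nh, k, nw, k) -> (B, nh, nw, D, k, k): patches in row-major order
        x = x.reshape(B, D, nh, k, nw, k).permute(0, 2, 4, 1, 3, 5)
        x = x.reshape(B * nh * nw, D, k, k)   # same layout as unfold + transpose + reshape
        x = x + pixel_pos
        x = x.reshape(B * nh * nw, D, -1).transpose(1, 2)
        return x
```

Because this only uses `reshape`, `permute` and slicing, its backward pass consists of the same view/copy ops, so neither direction should need `im2col_backward`.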