Unfold function takes too much GPU memory


I am trying to implement HaloNets, using the halo attention module from the timm library.

The HaloAttn module is meant to replace a convolution layer. However, when I test with input of size [2, 512, 256, 256] on two modules:

  • a HaloAttn layer with 789,888 parameters: 8 attention heads, block size 8, halo size 3
  • a convolutional layer with 2,359,808 parameters: nn.Conv2d(512, 512, 3, 1, 0)

With torch.no_grad(), HaloAttn used 1.44 GB for inference while the convolution used 0.21 GB. The conv costs less even though its parameter count is three times that of HaloAttn. So I dug a little deeper and measured GPU memory after each step inside HaloAttn; the result shows that the unfold call (code snippet below) alone accounts for 0.4 GB.

kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
            B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
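To see where that allocation comes from, it helps to look at what this call does with the settings above (block size 8, halo 3, so win_size = 8 + 2*3 = 14 with stride 8). A small stand-in sketch with toy shapes, not the actual HaloAttn tensors:

```python
import torch

# Settings from the question: block_size 8, halo 3 -> win_size 14.
block_size, halo_size = 8, 3
win_size = block_size + 2 * halo_size  # 14

# Toy stand-in for the (already halo-padded) kv tensor, not the real one.
x = torch.randn(1, 64, 38, 38)

# Overlapping windows: the stride (8) is smaller than the window (14),
# so neighbouring windows share elements.
patches = x.unfold(2, win_size, block_size).unfold(3, win_size, block_size)
print(patches.shape)                       # torch.Size([1, 64, 4, 4, 14, 14])

# unfold itself only returns a view into the same storage...
print(patches.data_ptr() == x.data_ptr())  # True

# ...but it now addresses more elements than the input
# (approaching (14/8)^2 ~ 3x for large feature maps):
print(patches.numel() / x.numel())

# The following .reshape cannot be expressed as a view on this strided
# layout, so it materializes the larger tensor -- that copy is the
# allocation being measured.
flat = patches.reshape(1, 64, 16, -1)
print(flat.data_ptr() == x.data_ptr())     # False
```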

Is there a more memory-friendly alternative to unfold? Or can someone explain why this happens?

unfold explicitly creates patches/windows of the input tensor, which can increase memory usage depending on the stride and window size.
I guess the unfold approach is used here for a “manual” conv-like implementation via a matmul?
If so, higher memory usage would be expected, since conv layers can use other kernels that do not operate on the unfolded input.
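As a sketch of that point: the unfold-plus-matmul route is essentially an im2col convolution, and you can see the blow-up directly with torch.nn.functional.unfold (toy shapes made up for illustration, not the HaloNet code):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)   # toy input
w = torch.randn(6, 4, 3, 3)   # toy 3x3 conv weight

# im2col: every 3x3 patch becomes a column, so the column matrix holds
# several times more elements than x (each interior pixel appears in up
# to 9 patches).
cols = F.unfold(x, kernel_size=3)   # shape (1, 4*3*3, 6*6) = (1, 36, 36)
print(cols.numel() / x.numel())

# "Manual" conv: flatten the weight and matmul against the columns.
out = w.view(6, -1) @ cols          # (1, 6, 36)
out = out.view(1, 6, 6, 6)

# A fused conv kernel computes the same result without ever
# materializing the column matrix.
ref = F.conv2d(x, w)
print(torch.allclose(out, ref, atol=1e-5))  # True
```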
