Slow arm64 Conv3d on Apple Silicon CPU

I’m noticing that for large input images, Conv3d is substantially slower (roughly 4x) and far more memory-intensive when using python/torch built for arm64. I’m seeing this on a 32GB M1 Pro running macOS 13.0.1.

I’ve installed a prebuilt x86_64 python distribution and a native arm64 python distribution, each with torch 1.13.1 installed via conda. Here is what /usr/bin/time reports when running the example below under each distribution:

python_x86_64: 31.96 seconds, 6.6 GB peak memory
python_arm64: 131.09 seconds, 26.2 GB peak memory
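(The numbers above come from /usr/bin/time. For anyone wanting a quick in-process check instead, a minimal sketch using the standard-library resource module works too; note that ru_maxrss is reported in bytes on macOS but in kilobytes on Linux, so the unit conversion below is platform-dependent.)

```python
import resource
import sys

def peak_rss_gb():
    """Return this process's peak resident set size in GB.

    ru_maxrss is in bytes on macOS (darwin) but in kilobytes on Linux,
    so the divisor depends on the platform.
    """
    rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    return rss / 1e9 if sys.platform == "darwin" else rss / 1e6

# Call this at the end of the script, after the forward pass.
print(f"peak RSS so far: {peak_rss_gb():.2f} GB")
```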

import torch


class ConvNet(torch.nn.Module):

    def __init__(self, ndims):
        super().__init__()
        channels = [1, 32, 32]
        self.convs = torch.nn.ModuleList()
        for i in range(len(channels) - 1):
            self.convs.append(torch.nn.Conv3d(channels[i], channels[i + 1], 3, 1))
            self.convs.append(torch.nn.LeakyReLU(0.2))

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        return x

device = 'cpu'
torch.set_num_threads(1)

image = torch.rand(1, 1, 192, 256, 256)
model = ConvNet(ndims=3).to(device)
output = model(image)

This only seems to happen for large inputs. In fact, with an input of size (128, 128, 128), the arm64 version is actually faster (though it still has a larger memory footprint). I don’t see this behavior with 2D convolutions. Any idea what is causing this?
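To isolate the size dependence, I’ve been using a small timing harness along these lines (a sketch with sizes reduced from the report above so it runs quickly; the single-layer setup and rep count are just choices for illustration). Running it under each interpreter and sweeping the size upward shows where the crossover happens:

```python
import time
import torch

torch.set_num_threads(1)

def time_conv3d(size, channels=32, reps=3):
    """Average wall time for one Conv3d forward pass on a
    (1, 1, size, size, size) input, after one warm-up pass."""
    conv = torch.nn.Conv3d(1, channels, 3, 1)
    x = torch.rand(1, 1, size, size, size)
    with torch.no_grad():
        conv(x)  # warm-up, excluded from timing
        start = time.perf_counter()
        for _ in range(reps):
            conv(x)
    return (time.perf_counter() - start) / reps

if __name__ == "__main__":
    for size in (32, 64, 96):
        print(f"size {size}^3: {time_conv3d(size):.3f} s per forward pass")
```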