Understanding torch.vmap with torch.channels_last_3d

Hello everyone,

I stumbled across the following behavior while playing around with torch.vmap in torch version 2.4.0. Consider this complete code snippet:

import torch

func = torch.vmap(lambda x, kernel: torch.nn.functional.conv3d(x, kernel, padding="same"), in_dims=(0, None))

data = torch.rand(16, 16, 16, 32, 32, 32) # [Group, Batch, Feature, X, Y, Z]
kernel = torch.rand(1, 16, 3, 3, 3) # [OutFeature, InFeature, X, Y, Z]
res = func(data, kernel)

print(res.shape) # torch.Size([Group, Batch, OutFeature, X, Y, Z])

This snippet works as expected, and the resulting shape is [16, 16, 1, 32, 32, 32]. Let's say that we now want to try doing this convolution in channels-last memory format to improve memory access patterns.
We change the lambda within the vmap as follows:

func = torch.vmap(lambda x, kernel: torch.nn.functional.conv3d(x.to(memory_format=torch.channels_last_3d), kernel, padding="same"), in_dims=(0, None))

I thought that this would work, but it actually causes vmap to throw the following:

RuntimeError: required rank 5 tensor to use channels_last_3d format

This leads me to believe that I do not understand how either vmap or channels_last_3d works. In my mind the tensor should be 5D, since this operation runs once for every Group; in fact, printing the tensor's shape within the function gives a 5D shape. Furthermore, I see no reason for channels_last_3d not to work: the memory is contiguous, and (as I understand it) the memory format conversion doesn't change anything by itself, it just passes information down to select which convolution algorithm will be used.
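
Here is a minimal, self-contained sketch of what I mean (the function name conv_fn is just for illustration; the printed shape and the error message are what I see with the snippet above):

import torch

def conv_fn(x, kernel):
    # Inside vmap the per-group slice reports a 5D logical shape...
    print(x.shape)  # torch.Size([16, 16, 32, 32, 32]) -> [Batch, Feature, X, Y, Z]
    # ...yet the memory format conversion below is what raises the error.
    x = x.to(memory_format=torch.channels_last_3d)
    return torch.nn.functional.conv3d(x, kernel, padding="same")

func = torch.vmap(conv_fn, in_dims=(0, None))

data = torch.rand(16, 16, 16, 32, 32, 32)  # [Group, Batch, Feature, X, Y, Z]
kernel = torch.rand(1, 16, 3, 3, 3)        # [OutFeature, InFeature, X, Y, Z]

try:
    func(data, kernel)
except RuntimeError as e:
    print(e)  # required rank 5 tensor to use channels_last_3d format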

The obvious workaround in my case is to flatten the two batch dimensions, which I can work with. But it would be interesting to know what is going on regardless.
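
For reference, here is a sketch of the flattening workaround I had in mind (the reshape logic is my own; I have not benchmarked whether channels_last_3d actually helps here):

import torch

data = torch.rand(16, 16, 16, 32, 32, 32)  # [Group, Batch, Feature, X, Y, Z]
kernel = torch.rand(1, 16, 3, 3, 3)        # [OutFeature, InFeature, X, Y, Z]

g, b = data.shape[:2]
flat = data.reshape(g * b, *data.shape[2:])                     # [Group*Batch, Feature, X, Y, Z]
flat = flat.to(memory_format=torch.channels_last_3d)            # rank 5, so the conversion succeeds
out = torch.nn.functional.conv3d(flat, kernel, padding="same")  # [Group*Batch, OutFeature, X, Y, Z]
res = out.reshape(g, b, *out.shape[1:])                         # [Group, Batch, OutFeature, X, Y, Z]

print(res.shape)  # torch.Size([16, 16, 1, 32, 32, 32])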

Thank you very much for your time,
Mike

I can reproduce this using the latest nightly binary. CC @richard, would you know why the rank 5 tensor is not detected?
