Hello everyone,
I stumbled across the following when playing around with torch.vmap in torch version 2.4.0. Consider the following complete code snippet:
import torch
func = torch.vmap(lambda x, kernel: torch.nn.functional.conv3d(x, kernel, padding="same"), in_dims=(0, None))
data = torch.rand(16, 16, 16, 32, 32, 32) # [Group, Batch, Feature, X, Y, Z]
kernel = torch.rand(1, 16, 3, 3, 3) # [OutFeature, InFeature, X, Y, Z]
res = func(data, kernel)
print(res.shape) # torch.Size([Group, Batch, OutFeature, X, Y, Z])
This snippet works as expected, and the resulting shape is [16, 16, 1, 32, 32, 32]. Let's say we now want to try doing this convolution using channels last to improve memory access patterns. We change the lambda inside the vmap as follows:
func = torch.vmap(lambda x, kernel: torch.nn.functional.conv3d(x.to(memory_format=torch.channels_last_3d), kernel, padding="same"), in_dims=(0, None))
I thought that this would work, but it actually causes vmap to throw the following:
RuntimeError: required rank 5 tensor to use channels_last_3d format
This leads me to believe that I do not understand how either vmap or channels_last_3d works. In my mind, the tensor should be 5D, since this operation runs once for every Group; in fact, printing the tensor's shape within the function gives a 5D shape. Furthermore, I see no reason for channels_last_3d not to work: the memory is contiguous, and (as I understand it) the memory-format cast doesn't change the data itself, it just passes information down to select which convolution algorithm will be used.
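For reference, this is roughly how I checked the shape inside the vmapped function (a minimal sketch of the debugging step; conv_fn is just a named version of the lambda above, with the memory-format cast removed so it runs through):
import torch
def conv_fn(x, kernel):
    # Under vmap this prints the per-group slice's shape, which is 5D:
    # torch.Size([16, 16, 32, 32, 32])  -> [Batch, Feature, X, Y, Z]
    print(x.shape)
    return torch.nn.functional.conv3d(x, kernel, padding="same")
func = torch.vmap(conv_fn, in_dims=(0, None))
res = func(torch.rand(16, 16, 16, 32, 32, 32), torch.rand(1, 16, 3, 3, 3))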
The obvious workaround in my case is to flatten the two batch dimensions (sketched below), which I can work with, but it would still be interesting to know what is going on here.
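For completeness, this is the kind of workaround I have in mind (a rough sketch, not necessarily the cleanest way to write it): fold Group and Batch into a single batch dimension, run a plain conv3d with channels_last_3d, and unfold afterwards.
import torch
data = torch.rand(16, 16, 16, 32, 32, 32)   # [Group, Batch, Feature, X, Y, Z]
kernel = torch.rand(1, 16, 3, 3, 3)          # [OutFeature, InFeature, X, Y, Z]
g, b = data.shape[:2]
flat = data.reshape(g * b, *data.shape[2:])  # [Group*Batch, Feature, X, Y, Z]
flat = flat.to(memory_format=torch.channels_last_3d)  # 5D, so the cast is allowed here
out = torch.nn.functional.conv3d(flat, kernel, padding="same")
out = out.reshape(g, b, *out.shape[1:])      # back to [Group, Batch, OutFeature, X, Y, Z]
print(out.shape)  # torch.Size([16, 16, 1, 32, 32, 32])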
Thank you very much for your time,
Mike