Is there a way to realize 4d convolution using the ConvNd function

I found that conv1d, conv2d, and conv3d all use ConvNd = torch._C._functions.ConvNd for the forward pass. Is there an efficient way to use ConvNd for 4-dimensional convolution? My inputs are 4-dimensional (6-dimensional if you count batch size and channels).

Thanks!


Not for the moment. We do not yet have nd convolution kernels, for n>3.


What’s the blocking issue?

Since we have the ConvNd function, why can’t we implement a 4d convolution just like the others? Why not n-dimensional?

From the Python side, all the convolution layers ultimately delegate to that ConvNd function. Nothing seems special-cased for any particular kernel shape.

I think the C++ side is implemented in torch/csrc/autograd/functions/convolution.{h,cpp}. The forward pass has a special case for n=3, but all it seems to do is reshape the kernel by adding/removing dummy dimensions; I'm not entirely sure why. Perhaps it could easily be generalized? I'm a bit out of my element with these low-level details, so I may be missing something.

It's mostly that we haven't implemented the C kernels backing ConvNd in the THNN/THCUNN libraries (ultimately, autograd/functions/convolution.cpp dispatches either to those or to cuDNN).

@smth could you please share if there are any developments on this front?


@Edward_Hahn Not really, because the demand for 4D convolution is not high. Your best bet is to compose N Conv3d layers into a Conv4d (there will be some boundary effects/issues, so you will have to pad the edges appropriately).
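The decomposition suggested above can be sketched as follows: slice the 4D kernel along its fourth axis into kT separate Conv3d layers, and sum their outputs frame by frame. This is a minimal illustration (stride 1, no padding along the fourth axis, the `Conv4d` name is made up here), not an official PyTorch API:

```python
import torch
import torch.nn as nn

class Conv4d(nn.Module):
    """Sketch of a 4D convolution built from kT separate Conv3d layers.

    Input shape: (N, C_in, T, D, H, W). Assumes stride 1 and no padding
    along the 4th (T) axis; illustrative only, not a PyTorch API.
    """
    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        kt, kd, kh, kw = kernel_size
        self.kt = kt
        # One Conv3d per slice of the 4D kernel along the T axis.
        # Bias only on the first layer, so it is added exactly once per frame.
        self.convs = nn.ModuleList(
            nn.Conv3d(in_channels, out_channels, (kd, kh, kw), bias=(i == 0))
            for i in range(kt)
        )

    def forward(self, x):
        n, c, t, d, h, w = x.shape
        t_out = t - self.kt + 1
        frames = []
        for j in range(t_out):
            # Sum the kT partial 3D convolutions contributing to output frame j.
            acc = self.convs[0](x[:, :, j])
            for i in range(1, self.kt):
                acc = acc + self.convs[i](x[:, :, j + i])
            frames.append(acc)
        # Stack along a new 4th spatial dim -> (N, C_out, T_out, D', H', W')
        return torch.stack(frames, dim=2)
```

Because each output frame needs its own sum over kT Conv3d calls, the boundary handling along the T axis (the padding mentioned above) is left to the caller.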

Just wanted to chime in and say that I’m interested in this feature, it seems like the natural way to represent time series of multi-channel satellite images.

Wouldn't this correspond to a 3D convolution, where the input has the shape [batch_size, in_channels, seq, height, width]? Or is each “image” a volume?
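If each time step is a 2D multi-channel image (not a volume), the existing nn.Conv3d already covers this case by treating the sequence axis as the depth dimension. A small sketch with made-up sizes (4 samples, 13 spectral bands, 10 time steps):

```python
import torch
import torch.nn as nn

# A time series of multi-channel 2D images as a 3D-conv problem:
# the sequence axis plays the role of the depth dimension.
x = torch.randn(4, 13, 10, 64, 64)  # (batch, bands, time, H, W), illustrative sizes

# kernel_size=3 with padding=1 keeps all three "spatial" extents unchanged.
conv = nn.Conv3d(in_channels=13, out_channels=32, kernel_size=3, padding=1)
y = conv(x)
print(y.shape)  # torch.Size([4, 32, 10, 64, 64])
```

A true 4D convolution is only needed when each time step is itself a 3D volume.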

Here you can find a fully functional convNd and an N-dimensional transposed convolution built on conv3d, as smth suggested.

But I'm also looking forward to a native 4D convolution, mainly for speed reasons. I've started on an implementation, but it does indeed seem like a long way to go.


Thanks for your work!

I'd also LOVE to see this implemented. In my case, it's a natural way to represent a time series of 3D microscopy measurements. I hope the PyTorch team will consider implementing it.


4D convolution eats a lot of FLOPs, though.

JAX seems to have N-dimensional convolution (via jax.lax.conv_general_dilated).