Custom 3d convolutions with 2d "internal convolutions"

In a regular convolution, the kernel is multiplied elementwise with the overlapping patch of the input and the products are summed into a single scalar. The kernel is then shifted and the process repeats. My problem is that, instead of producing a scalar, each 3d point has an “internal” 2d plane that needs to be convolved over. So at every position, the 3d convolution has to return a 2d convolution result rather than a scalar.

I am not sure how to implement something like this. A good first step would be to efficiently extract the 3x3x3 chunks of the input, the same way PyTorch does for regular 3d convolutions. I have looked at conv3d in torch/nn/quantized/functional.py but can't find the part where the actual convolutions are computed.
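The sketch below shows one way to extract 3x3x3 patches. It calls `Tensor.unfold` once per spatial dimension, and the helper name `extract_patches_3d` is hypothetical. To check that the patches match what `F.conv3d` sees, it contracts them with a weight through `einsum` and compares the result with `F.conv3d`. Like the other PyTorch conv functions, `F.conv3d` computes a cross-correlation.

```python
import torch
import torch.nn.functional as F

def extract_patches_3d(x, k=3, stride=1, padding=0):
    """Hypothetical helper: all k x k x k patches of an (N, C, D, H, W) tensor.

    Returns a view of shape (N, C, D_out, H_out, W_out, k, k, k), where
    patches[n, c, d, h, w, i, j, l] == x_padded[n, c, d*stride+i, h*stride+j, w*stride+l].
    """
    if padding:
        # Zero-pad the last three dims (W, H, D), like conv3d's `padding`
        x = F.pad(x, (padding,) * 6)
    # Tensor.unfold(dim, size, step) slides a window along `dim`
    # and appends a new trailing dim of length `size`
    return x.unfold(2, k, stride).unfold(3, k, stride).unfold(4, k, stride)

torch.manual_seed(0)
x = torch.randn(2, 4, 6, 7, 8)   # (N, C, D, H, W)
w = torch.randn(5, 4, 3, 3, 3)   # (C_out, C_in, kD, kH, kW)

patches = extract_patches_3d(x)  # (2, 4, 4, 5, 6, 3, 3, 3)
# Multiply each patch with the kernel and sum over C_in and the 3x3x3 window
out = torch.einsum("ncdhwijk,ocijk->nodhw", patches, w)
ref = F.conv3d(x, w)
print(torch.allclose(out, ref, atol=1e-4))  # expected: True
```

Because `unfold` returns views, no data is copied until the patches are consumed, for example by `einsum` or `reshape`.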

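To create these “patches” you could use nn.Unfold, whose documentation example also shows how to perform the convolution manually from the unfolded patches. Note that nn.Unfold only supports 4D (batched 2d) inputs. For 3d volumes, Tensor.unfold can be applied once per spatial dimension instead.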
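Here is one possible way to apply these patches to the operation in the question. The assumptions are mine, not the original poster's. The input is taken to have shape (N, D, H, W, P, Q), with a single-channel P x Q plane at every 3d point, and each of the 3x3x3 kernel offsets is assumed to have its own 2d kernel. Under that reading, the output at each 3d position is the sum of the 2d convolutions of the neighboring planes, each with its matching 2d kernel. `F.conv2d` already sums over input channels, so the 27 neighboring planes can be stacked as channels and handled in a single call. The function name `conv3d_of_planes` is hypothetical.

```python
import torch
import torch.nn.functional as F

def conv3d_of_planes(x, weight):
    """Hypothetical 3d convolution whose per-position result is a 2d convolution.

    x:      (N, D, H, W, P, Q)   assumed: one P x Q plane per 3d grid point
    weight: (kd, kh, kw, kp, kq) assumed: one 2d kernel per 3d kernel offset
    returns (N, D_out, H_out, W_out, P_out, Q_out); stride 1, no padding
    """
    kd, kh, kw, kp, kq = weight.shape
    N, D, H, W, P, Q = x.shape
    # 3d neighborhoods: (N, D_out, H_out, W_out, P, Q, kd, kh, kw)
    patches = x.unfold(1, kd, 1).unfold(2, kh, 1).unfold(3, kw, 1)
    D_out, H_out, W_out = patches.shape[1:4]
    # Move the kd*kh*kw neighbor axes in front of the plane axes ...
    patches = patches.permute(0, 1, 2, 3, 6, 7, 8, 4, 5)
    # ... and use them as conv2d input channels (this reshape copies data)
    patches = patches.reshape(-1, kd * kh * kw, P, Q)
    # conv2d sums over input channels, i.e. over the 3d neighbors
    out = F.conv2d(patches, weight.reshape(1, kd * kh * kw, kp, kq))
    return out.reshape(N, D_out, H_out, W_out, *out.shape[-2:])

torch.manual_seed(0)
x = torch.randn(2, 4, 5, 4, 6, 7)   # N=2, D=4, H=5, W=4, 6 x 7 planes
w = torch.randn(3, 3, 3, 2, 3)      # 3x3x3 grid of 2 x 3 kernels
y = conv3d_of_planes(x, w)
print(y.shape)  # torch.Size([2, 2, 3, 2, 5, 5])
```

The `reshape` materializes every neighborhood, so memory grows roughly 27x over the input. For large volumes, looping over the 27 offsets and accumulating one `F.conv2d` result per offset would save memory at the cost of speed.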