nn.Conv2d and nn.Linear for 3D point operations

In PointNet++, suppose the input point data has shape [1, 3, 1024, 32], where 1 is the batch size, 3 is the xyz coordinates, 1024 is the number of centroids, and 32 is the number of neighbors per centroid.

The input data is then passed to nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(1, 1)).

Can we use nn.Linear instead of nn.Conv2d? That is, reshape the input point data to [1 * 1024 * 32, 3] and then pass it to nn.Linear(3, 64).

Do these two have the same effect, given that the Conv2d kernel size is (1, 1)?
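A note on the reshape in the question: the xyz (channel) dimension has to be moved last before flattening. A plain `reshape` of the `[1, 3, 1024, 32]` tensor to `[1 * 1024 * 32, 3]` follows memory order, so each row would mostly hold values of a single coordinate taken from different points rather than one point's x, y and z. The sketch below uses random data as a stand-in for real PointNet++ grouped points and compares both reshapes against the 1x1 `nn.Conv2d`.

```python
import torch

torch.manual_seed(0)

# Random stand-in for grouped points: [batch, xyz, centroids, neighbors]
x = torch.randn(1, 3, 1024, 32)

conv = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(1, 1))
lin = torch.nn.Linear(3, 64)

# Copy the conv parameters into the linear layer: [64, 3, 1, 1] -> [64, 3]
with torch.no_grad():
    lin.weight.copy_(conv.weight.flatten(1))
    lin.bias.copy_(conv.bias)

res_conv = conv(x)  # [1, 64, 1024, 32]

# Correct: move channels last, then flatten to [1 * 1024 * 32, 3]
rows = x.permute(0, 2, 3, 1).reshape(-1, 3)
res_lin = lin(rows).reshape(1, 1024, 32, 64).permute(0, 3, 1, 2)

# Wrong: a plain reshape groups values by memory order, not by point
naive_rows = x.reshape(-1, 3)
res_naive = lin(naive_rows).reshape(1, 1024, 32, 64).permute(0, 3, 1, 2)

matches_permute = torch.allclose(res_conv, res_lin, atol=1e-6)
matches_naive = torch.allclose(res_conv, res_naive, atol=1e-6)
print(matches_permute, matches_naive)
```

Only the permuted version should reproduce the convolution output; the naive reshape produces a tensor of the right shape but with the wrong values.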

Hi Elijah!

Yes (although I’m not sure what the point would be).
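The reason is that a 1x1 convolution has no spatial extent: at every (centroid, neighbor) position it only takes a weighted sum over the input channels. Writing the `Conv2d` weight of shape `[C_out, C_in, 1, 1]` as $w$ and its bias as $\beta$, the output at batch index $b$, output channel $o$, centroid $i$ and neighbor $j$ is

```latex
y_{b,o,i,j} = \beta_o + \sum_{c=1}^{C_{\text{in}}} W_{o,c}\, x_{b,c,i,j},
\qquad W_{o,c} = w_{o,c,0,0}
```

which is exactly `nn.Linear(C_in, C_out)`, computing $y = W x + \beta$, applied independently to each channel vector $x_{b,:,i,j}$.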

Consider the following example (with fewer centroids and neighbors for
convenience):

>>> import torch
>>> print (torch.__version__)
2.2.1
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> conv = torch.nn.Conv2d (in_channels = 3, out_channels = 5, kernel_size = (1, 1))
>>>
>>> x = torch.randn (1, 3, 2, 4)
>>>
>>> resConv = conv (x)
>>>
>>> lin = torch.nn.Linear (3, 5)
>>> with torch.no_grad():
...     _ = lin.weight.copy_ (conv.weight.squeeze())
...     _ = lin.bias.copy_ (conv.bias)
...
>>> resLin = lin (x.permute (0, 2, 3, 1).flatten (0, 2)).reshape (1, 2, 4, 5).permute (0, 3, 1, 2)
>>>
>>> torch.allclose (resConv, resLin)
True
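A slightly shorter variant of the same check, as a sketch: `nn.Linear` is applied over the last dimension of any input of shape `[*, in_features]`, so permuting the channels last is enough and the `flatten`/`reshape` round trip can be dropped. The `torch.einsum` line spells out the per-point weighted sum that both layers compute. `flatten(1)` is used here instead of `squeeze()` to copy the weight, because `squeeze()` would also remove a channel dimension of size 1 (for example with `in_channels=1`).

```python
import torch

torch.manual_seed(2024)

conv = torch.nn.Conv2d(in_channels=3, out_channels=5, kernel_size=(1, 1))
lin = torch.nn.Linear(3, 5)

# flatten(1) maps [5, 3, 1, 1] -> [5, 3] without dropping a size-1 channel dim
with torch.no_grad():
    lin.weight.copy_(conv.weight.flatten(1))
    lin.bias.copy_(conv.bias)

x = torch.randn(1, 3, 2, 4)
res_conv = conv(x)

# nn.Linear acts on the last dimension, so no flatten/reshape is needed:
# [1, 3, 2, 4] -> [1, 2, 4, 3] -> Linear -> [1, 2, 4, 5] -> [1, 5, 2, 4]
res_lin = lin(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

# The same computation written out: y[b,o,i,j] = bias[o] + sum_c W[o,c] * x[b,c,i,j]
W = conv.weight.flatten(1)
res_einsum = torch.einsum("oc,bcij->boij", W, x) + conv.bias.view(1, -1, 1, 1)

matches_lin = torch.allclose(res_conv, res_lin, atol=1e-6)
matches_einsum = torch.allclose(res_conv, res_einsum, atol=1e-6)
print(matches_lin, matches_einsum)
```

The permuted result is a non-contiguous view, so a later `.view(...)` on it would need a `.contiguous()` call first; `.reshape(...)` handles this automatically.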

Best.

K. Frank

Thank you. This is what I need.