How to use k groups of c channels from a CNN for k FC layers

I have an encoder that outputs a tensor with shape (bn, c * k, 32, 32). From it I want to produce k means as a tensor with shape (bn, k, 1, 2), so each mean is a 2-dim coordinate. To do so, I want to use k FC layers, where the layer for mean k_i only sees its own c channels.

So my idea is to reshape the encoder output out to a 5-d tensor with shape (bn, k, c, 32, 32). Then I can use the flattened slices out[:, 0], ..., out[:, k-1] as the inputs to the k linear layers.
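A minimal sketch of the reshape step, with hypothetical sizes and a random tensor standing in for the encoder output. It assumes out is contiguous (as a conv output normally is), so view splits the channel dimension into k consecutive groups: keypoint i gets channels i*c to (i+1)*c - 1.

```python
import torch

# Hypothetical sizes for illustration
bn, k, c = 2, 5, 3
out = torch.rand(bn, c * k, 32, 32)  # stand-in for the encoder output

# Split the channel dimension into k groups of c channels each
out5d = out.view(bn, k, c, 32, 32)

# Flatten each group so it can be fed to its own linear layer
flat = [out5d[:, i].reshape(bn, -1) for i in range(k)]
print(flat[0].shape)  # torch.Size([2, 3072])
```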

The trivial solution would be to define the linear layers manually:

self.fc0 = nn.Linear(c * 32 * 32, 2)
# ... one nn.Linear(c * 32 * 32, 2) per keypoint ...
self.fck = nn.Linear(c * 32 * 32, 2)

Then I could define the forward pass for each mean as follows:

mean_0 = self.fc0(out[:, 0].reshape(bn, -1))
# ... one call per keypoint ...
mean_k = self.fck(out[:, k - 1].reshape(bn, -1))  # last keypoint (index k - 1)
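The manual approach can be written more compactly with nn.ModuleList, so the number of layers follows k instead of being hard-coded. This is a sketch; the class name KeypointHeads and its arguments are hypothetical. It removes the boilerplate but still runs the k layers one after another in a Python loop, so it is not necessarily faster.

```python
import torch
import torch.nn as nn

class KeypointHeads(nn.Module):  # hypothetical name
    def __init__(self, k, c, h=32, w=32):
        super().__init__()
        self.k, self.c = k, c
        # One independent Linear(c*h*w -> 2) per keypoint, registered via ModuleList
        self.fcs = nn.ModuleList([nn.Linear(c * h * w, 2) for _ in range(k)])

    def forward(self, out):
        bn = out.shape[0]
        out = out.view(bn, self.k, self.c, *out.shape[2:])
        # Apply the i-th layer to the i-th group of c channels
        means = [fc(out[:, i].reshape(bn, -1)) for i, fc in enumerate(self.fcs)]
        # Stack to (bn, k, 2), then add a singleton dim -> (bn, k, 1, 2)
        return torch.stack(means, dim=1).unsqueeze(2)

heads = KeypointHeads(k=5, c=3)
means = heads(torch.rand(2, 15, 32, 32))
print(means.shape)  # torch.Size([2, 5, 1, 2])
```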

Is there a more efficient way to do that?

Edit: To give a little more background: I want to find the keypoints/landmarks of the input image. The idea is that I assign each keypoint k its own set of c feature maps. That's why the encoder output, after reshaping, has the shape (bn, k, c, 32, 32). Now I want those c feature maps to predict their keypoint, which is a 2-dim coordinate. Essentially, I want a separate FC layer for each keypoint.
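"A separate FC layer for each keypoint" can also be expressed as one batched contraction: stack the k weight matrices into a single parameter of shape (k, c*h*w, 2) and apply them all with torch.einsum. This is a sketch of that alternative, not something from the original post; the class name BatchedKeypointFC is hypothetical, and the uniform initialization bound 1/sqrt(fan_in) is chosen to roughly match nn.Linear's default range. It avoids the Python loop over k layers, but I have not benchmarked it against the other approaches.

```python
import torch
import torch.nn as nn

class BatchedKeypointFC(nn.Module):  # hypothetical name
    """k independent linear maps (c*h*w -> 2), applied in one einsum call."""
    def __init__(self, k, c, h=32, w=32):
        super().__init__()
        self.k = k
        in_features = c * h * w
        bound = 1.0 / in_features ** 0.5
        # weight[i] is the (in_features x 2) matrix of keypoint i
        self.weight = nn.Parameter(torch.empty(k, in_features, 2).uniform_(-bound, bound))
        self.bias = nn.Parameter(torch.empty(k, 2).uniform_(-bound, bound))

    def forward(self, out):
        bn = out.shape[0]
        x = out.reshape(bn, self.k, -1)  # (bn, k, c*h*w), keypoint i = channels i*c..(i+1)*c-1
        # For each keypoint i: x[:, i] @ weight[i] + bias[i]
        y = torch.einsum('bkf,kfo->bko', x, self.weight) + self.bias
        return y.unsqueeze(2)  # (bn, k, 1, 2)

model = BatchedKeypointFC(k=5, c=3)
x = torch.rand(2, 15, 32, 32)
y = model(x)
print(y.shape)  # torch.Size([2, 5, 1, 2])
```

Each keypoint still has its own independent weights; only the way they are stored and applied changes.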

I have arrived at the following solution, using the groups argument of nn.Conv2d:

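import torch
import torch.nn as nn
bn = 1; k = 5; c = 3
x = torch.rand(bn, k*c, 32, 32)
m = nn.Conv2d(in_channels=c*k, out_channels=2*k, kernel_size=32, groups=k)  # each of the k groups maps its c channels to 2 outputs
means = m(x).view(bn, k, 1, 2)  # (bn, 2k, 1, 1) -> (bn, k, 1, 2)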
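Why the grouped convolution works: with kernel_size equal to the 32x32 spatial size and no padding, each output channel is a single dot product over all c*32*32 inputs of its group, plus a bias, which is exactly what nn.Linear computes. The sketch below checks this numerically by copying the conv parameters into k separate nn.Linear layers. It relies on nn.Conv2d storing its weight as (out_channels, in_channels // groups, kH, kW), so rows 2i and 2i+1 belong to group i; the sizes are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn, k, c = 2, 5, 3
x = torch.rand(bn, k * c, 32, 32)

# Grouped conv: k groups, each sees c input channels and produces 2 outputs
conv = nn.Conv2d(in_channels=c * k, out_channels=2 * k, kernel_size=32, groups=k)
means_conv = conv(x).view(bn, k, 1, 2)  # (bn, 2k, 1, 1) -> (bn, k, 1, 2)

# Build k separate Linear layers from the conv parameters
# conv.weight has shape (2k, c, 32, 32); rows 2i, 2i+1 belong to group i
fcs = []
for i in range(k):
    fc = nn.Linear(c * 32 * 32, 2)
    with torch.no_grad():
        fc.weight.copy_(conv.weight[2 * i:2 * i + 2].reshape(2, -1))
        fc.bias.copy_(conv.bias[2 * i:2 * i + 2])
    fcs.append(fc)

# Apply layer i to the flattened i-th group of c channels
x5d = x.view(bn, k, c, 32, 32)
means_fc = torch.stack([fcs[i](x5d[:, i].reshape(bn, -1)) for i in range(k)], dim=1).unsqueeze(2)

print(torch.allclose(means_conv, means_fc, atol=1e-4))
```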