Speeding up huge linear layer?

Edit: Replaced an earlier post after some investigating.

I’ve been trying to predict a 24x32 map with 42 classes from a 256-dimensional vector. That’s about 32,000 output units (24x32x42 = 32,256), so the final linear layer has roughly 256 x 32,256 ≈ 8.3M weights, and the hidden features come in BATCH_SIZE x 36 groups. This layer has become a bottleneck in my model. What are my options for speeding it up? Are there options other than parallelizing? (I have access to up to 4 GPUs.)

Here is the bottleneck with dimensions in the comments:

```python
# visual_feat.shape = BATCH_SIZE x 36 x 256
# img_features.shape = BATCH_SIZE x 36 x 2048
# map_dims = (24, 32)
# num_classes = 42
self.map_pred(visual_feat).reshape(
    img_features.shape[0], self.num_classes,
    *img_features.shape[1:-1], *self.map_dims,
)
```
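For reference, here is a minimal, self-contained reproduction of the layer (a sketch: `map_pred` is assumed to be a single `nn.Linear(256, 42*24*32)`, which the post doesn't show), plus one common speed-up, a low-rank factorization. The layer names and the rank value are illustrative assumptions, not the author's actual code:

```python
import torch
import torch.nn as nn

BATCH, REGIONS, FEAT = 8, 36, 256          # BATCH_SIZE x 36 x 256
NUM_CLASSES, MAP_H, MAP_W = 42, 24, 32     # 42 classes, 24x32 map

# Assumed definition of map_pred: one big linear layer, ~8.3M weights.
map_pred = nn.Linear(FEAT, NUM_CLASSES * MAP_H * MAP_W)

visual_feat = torch.randn(BATCH, REGIONS, FEAT)
out = map_pred(visual_feat)                # BATCH x 36 x 32256

# Note: reshaping straight to (B, classes, 36, 24, 32) as in the post
# interleaves the region and class axes; view per region first, then permute.
out = out.view(BATCH, REGIONS, NUM_CLASSES, MAP_H, MAP_W).permute(0, 2, 1, 3, 4)
# out.shape == (8, 42, 36, 24, 32)

# One possible speed-up (an illustration, not from the post): factor the
# layer through a low-rank bottleneck, cutting weights/FLOPs by ~4x here.
RANK = 64
map_pred_lowrank = nn.Sequential(
    nn.Linear(FEAT, RANK, bias=False),            # 256 -> 64
    nn.Linear(RANK, NUM_CLASSES * MAP_H * MAP_W)  # 64 -> 32256
)
out_lr = map_pred_lowrank(visual_feat)            # BATCH x 36 x 32256
```

Whether the low-rank version loses accuracy depends on the task; it is one knob to try before resorting to multi-GPU parallelism.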