Running independent `nn.Module` instances in `nn.ModuleList` truly in parallel in PyTorch

I have a PyTorch model that consists of multiple independent `FullyConnectedNetwork` instances stored inside an `nn.ModuleList`. Here’s the code:

import torch
import torch.nn as nn

class FullyConnectedNetwork(nn.Module):
    def __init__(self):
        super(FullyConnectedNetwork, self).__init__()
        self.fc1 = nn.Linear(20, 10)
        self.fc2 = nn.Linear(10, 1)
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

class ParallelFCN(nn.Module):
    def __init__(self, n):
        super(ParallelFCN, self).__init__()
        self.models = nn.ModuleList([FullyConnectedNetwork() for _ in range(n)])
    def forward(self, x):
        # each sub-network gets its own 20-feature slice of the input
        outputs = [model(x[:, i*20:(i+1)*20]) for i, model in enumerate(self.models)]
        return torch.cat(outputs, dim=1)

# Example usage:
n = 1000
model = ParallelFCN(n)

Currently, I’m passing data through each `FullyConnectedNetwork` instance in a Python for-loop. I realize this approach runs the sub-networks sequentially, not truly in parallel.

Given that each `FullyConnectedNetwork` is independent of the others, is there a way to run them truly in parallel, perhaps using multi-threading, multi-processing, or any other method in PyTorch?

I need this because the number of modules can get quite large (as many as 400 in my case), and processing them with a for-loop is very slow.
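For reference, one direction I have been looking at is PyTorch's `torch.func` ensembling utilities (`stack_module_state` + `functional_call` + `vmap`), which stack the per-model parameters and evaluate all sub-networks in a single vectorized call instead of a Python loop. A minimal sketch, assuming PyTorch 2.0+ (the layer sizes match my model above; `batch` and `n` are placeholder values):

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state


class FullyConnectedNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc2(self.fc1(x))


n = 400
models = [FullyConnectedNetwork() for _ in range(n)]

# Stack the parameters of all models into single tensors with a leading
# model dimension, e.g. fc1.weight becomes shape (n, 10, 20).
params, buffers = stack_module_state(models)

# A stateless "template" module; its own weights live on the meta device
# and are never used (deepcopy so the original models stay untouched).
base = copy.deepcopy(models[0]).to("meta")


def call_one(p, b, x):
    # Run the template with one model's parameters on one input slice.
    return functional_call(base, (p, b), (x,))


batch = torch.randn(32, n * 20)
# Reshape so each model sees its own 20-feature slice: (n, batch, 20).
x = batch.view(32, n, 20).transpose(0, 1)

# vmap maps over the leading (model) dimension of params, buffers, and x.
out = torch.vmap(call_one)(params, buffers, x)      # (n, 32, 1)
out = out.transpose(0, 1).reshape(32, n)            # same as torch.cat(..., dim=1)
```

This replaces the n separate small matmuls with one batched matmul per layer, which is usually much faster on GPU, though it is still a single-process/single-stream computation rather than true multi-threading.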