Hi gang, been a while. So I have this simple model:
import torch
import torch.nn as nn

class ModelSerial(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        # one small MLP per output feature
        blocks = [
            nn.Sequential(
                nn.Linear(in_features, hidden),
                nn.ReLU(inplace=True),
                nn.Linear(hidden, hidden),
                nn.ReLU(inplace=True),
                nn.Linear(hidden, 1),
            )
            for _ in range(out_features)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        # run every block on the same input and concatenate the scalar outputs
        out = torch.cat([b(x) for b in self.blocks], dim=-1)
        return out
Notice how each sequential block's output is not chained in the forward pass. Rather, each block is applied to the input separately, and the results are simply concatenated. As one would imagine, it runs extremely slowly, especially as the number of output features or the depth of the network increases. This is due to the Python for-loop. I had considered trying TorchScript JIT, but that wasn't an option for this particular engagement. So, next best thing? Try to parallelize it with grouped convolutions. Here's my attempt:
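Before going further, here's a quick standalone sanity check of how Conv1d groups carve up channels. As far as I can tell, they take contiguous chunks of the channel dimension, which a zeroing trick can verify:

```python
import torch
import torch.nn as nn

# Conv1d groups split the channel dimension into contiguous chunks:
# with in_channels=4 and groups=2, group 0 reads channels 0-1 and
# group 1 reads channels 2-3 (a "chunked", not "interlaced", split).
conv = nn.Conv1d(4, 2, kernel_size=1, groups=2, bias=False)

x = torch.randn(1, 4, 1)
x_zeroed = x.clone()
x_zeroed[:, 2:] = 0.0  # wipe the channels belonging to group 1

out = conv(x)
out_zeroed = conv(x_zeroed)

# group 0's output (channel 0) never saw channels 2-3, so it is unchanged;
# group 1's output (channel 1) collapses to zero once its inputs are zeroed
assert torch.allclose(out[:, 0], out_zeroed[:, 0])
assert torch.allclose(out_zeroed[:, 1], torch.zeros_like(out_zeroed[:, 1]))
```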
class ModelParallel(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.out_features = out_features
        # one group per output feature; with kernel_size=1, each grouped
        # conv acts as an independent per-group linear layer
        self.block = nn.Sequential(
            nn.Conv1d(
                in_features * self.out_features,
                hidden * self.out_features,
                kernel_size=1,
                groups=self.out_features,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                hidden * self.out_features,
                hidden * self.out_features,
                kernel_size=1,
                groups=self.out_features,
            ),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                hidden * self.out_features,
                out_features,
                kernel_size=1,
                groups=self.out_features,
            ),
        )

    def forward(self, x):
        # torch.Size([B, 150])
        bs = x.shape[0]
        # torch.Size([B, 150, 1])
        x = x[:, :, None]
        # torch.Size([B, 150, 50])
        x = x.expand(bs, x.shape[1], self.out_features)
        # torch.Size([B, 7500, 1]) because Conv1d works on [B, C, L]
        x = x.reshape(bs, -1, 1)
        # torch.Size([B, 50, 1])
        x = self.block(x)
        # torch.Size([B, 50])
        x = x.squeeze(dim=-1)
        return x
The model runs, and the input and output sizes are what I expect. But it doesn't train at all. Not even close: validation and train loss are completely out of sync; it's a mess.
I believe my lack of understanding of how groups are computed is what's messing things up. So I even tried re-arranging the setup as follows, in case the groups are "interlaced" as opposed to "chunked":
########
# Change:
# torch.Size([B, 150, 1])
x = x[:, :, None]
# torch.Size([B, 150, 50])
x = x.expand(bs, x.shape[1], self.out_features)
########
# Into:
# torch.Size([B, 1, 150])
x = x[:, None, :]
# torch.Size([B, 50, 150])
x = x.expand(bs, self.out_features, x.shape[2])
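To make the difference between the two arrangements concrete, here's what they do to the channel order on a toy tensor (in_features=3, out_features=2):

```python
import torch

# toy input with B=1, in_features=3, out_features (copies)=2
x = torch.arange(3.)[None, :]  # [[0., 1., 2.]]

# first layout: feature-major channel order
a = x[:, :, None].expand(1, 3, 2).reshape(1, -1, 1)
# channels: f0, f0, f1, f1, f2, f2
# -> group 0 (channels 0-2) sees f0, f0, f1, never f2

# second layout: copy-major channel order
b = x[:, None, :].expand(1, 2, 3).reshape(1, -1, 1)
# channels: f0, f1, f2, f0, f1, f2
# -> each group of 3 channels sees every feature once

print(a.flatten().tolist())  # [0.0, 0.0, 1.0, 1.0, 2.0, 2.0]
print(b.flatten().tolist())  # [0.0, 1.0, 2.0, 0.0, 1.0, 2.0]
```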
But unfortunately, that still didn't train either. Epoch time was a hell of a lot faster, though, as desired. My example setup:
ms = ModelSerial(
    in_features=150,
    hidden=10,
    out_features=50,
)
mp = ModelParallel(
    in_features=150,
    hidden=10,
    out_features=50,
)
I wouldn't expect the results to be exact; after all, linear and conv layers use different weight-initialization schemes even when determinism is set. But I would at least expect the parallel model to train. PyTorch senpais, can you please give me some guidance on how to architect this problem so it executes in parallel? Thank you!!
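In case it helps anyone diagnose this, here is a minimal parity sketch I put together (shrunken sizes, and one hidden layer dropped for brevity): if the Linear weights are copied into the grouped convs and the copy-major layout from the second attempt is used, the grouped-conv formulation should reproduce the serial outputs exactly:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_f, hid, G = 3, 4, 2  # small sizes for a quick check

# serial reference: G independent 2-layer MLPs
blocks = nn.ModuleList([
    nn.Sequential(nn.Linear(in_f, hid), nn.ReLU(), nn.Linear(hid, 1))
    for _ in range(G)
])

# grouped-conv twin with the same depth
conv1 = nn.Conv1d(in_f * G, hid * G, kernel_size=1, groups=G)
conv2 = nn.Conv1d(hid * G, G, kernel_size=1, groups=G)

# copy weights: group g of conv1/conv2 corresponds to blocks[g]
with torch.no_grad():
    for g, blk in enumerate(blocks):
        conv1.weight[g * hid:(g + 1) * hid, :, 0] = blk[0].weight
        conv1.bias[g * hid:(g + 1) * hid] = blk[0].bias
        conv2.weight[g, :, 0] = blk[2].weight[0]
        conv2.bias[g] = blk[2].bias[0]

x = torch.randn(5, in_f)
ref = torch.cat([b(x) for b in blocks], dim=-1)         # [5, G]

# copy-major layout: channel order is (copy 0: all features, copy 1: ...)
xc = x[:, None, :].expand(-1, G, -1).reshape(5, -1, 1)  # [5, G*in_f, 1]
par = conv2(torch.relu(conv1(xc))).squeeze(-1)          # [5, G]

assert torch.allclose(ref, par, atol=1e-6)
```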