I recommend that PyTorch provide a method for applying multiple independent nn modules in parallel, such as a segregated layer. At a minimum, I recommend providing a wrapper for the nn.Conv1d workaround commonly used to parallelise the execution of independent nn.Linear layers. For example:
import torch
import torch.nn as nn

class LinearSegregated(nn.Module):
    def __init__(self, in_features, out_features, number_sublayers):
        super().__init__()
        # A grouped 1x1 convolution computes number_sublayers independent
        # linear transformations in a single kernel launch.
        self.segregatedLinear = nn.Conv1d(in_channels=in_features*number_sublayers, out_channels=out_features*number_sublayers, kernel_size=1, groups=number_sublayers)
        self.number_sublayers = number_sublayers

    def forward(self, x):
        # x.shape = batch_size, number_sublayers, in_features
        x = x.view(x.shape[0], x.shape[1]*x.shape[2], 1)  # fold sublayers into channels
        x = self.segregatedLinear(x)
        x = x.view(x.shape[0], self.number_sublayers, x.shape[1]//self.number_sublayers)
        # x.shape = batch_size, number_sublayers, out_features
        return x

segregatedLinear = LinearSegregated(in_features=in_features, out_features=out_features, number_sublayers=linearSublayersNumber)
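As a sanity check (not part of the original proposal), here is a minimal sketch that copies the parameters of independent nn.Linear layers into the grouped Conv1d of the LinearSegregated class above and confirms both paths produce the same output. The dimension values are arbitrary placeholders:

import torch
import torch.nn as nn

batch_size, in_features, out_features, number_sublayers = 4, 8, 16, 3

segregated = LinearSegregated(in_features, out_features, number_sublayers)
linears = [nn.Linear(in_features, out_features) for _ in range(number_sublayers)]

# Copy each nn.Linear's parameters into the corresponding Conv1d group.
with torch.no_grad():
    weight = torch.cat([l.weight for l in linears], dim=0)          # (out*sub, in)
    segregated.segregatedLinear.weight.copy_(weight.unsqueeze(-1))  # (out*sub, in, 1)
    segregated.segregatedLinear.bias.copy_(torch.cat([l.bias for l in linears], dim=0))

x = torch.randn(batch_size, number_sublayers, in_features)
out_grouped = segregated(x)  # single grouped-conv call
out_looped = torch.stack([linears[i](x[:, i]) for i in range(number_sublayers)], dim=1)
print(torch.allclose(out_grouped, out_looped, atol=1e-6))  # True

The looped version above is the behaviour the workaround replaces: number_sublayers separate kernel launches collapse into one grouped convolution.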