I am trying to train a list of neural networks in parallel, using nn.ModuleList.
import torch
import torch.nn as nn

class Modules(nn.Module):
    def __init__(self, n_models, n_hidden):
        super().__init__()
        # the name "modules" would shadow the built-in nn.Module.modules() method
        self.nets = nn.ModuleList([nn.Sequential(nn.Linear(1, n_hidden), nn.Linear(n_hidden, 1))
                                   for _ in range(n_models)])

    def forward(self, x):
        # concatenate the per-model outputs along the last dimension
        return torch.cat([net(x) for net in self.nets], dim=-1)
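For concreteness, a quick shape check (the sizes here are arbitrary):

model = Modules(n_models=3, n_hidden=16)
x = torch.randn(4, 1)   # batch of 4 scalar inputs
print(model(x).shape)   # torch.Size([4, 3]): one output column per module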
...
def train_epoch(model, criterion, optimizer, train_dl):
    model.train()
    criterion.train()
    for data, target in train_dl:
        optimizer.zero_grad(set_to_none=True)
        prediction = model(data)
        # criterion is unreduced, giving one loss per module, shape [n_models, 1]
        loss = criterion(input=prediction, target=target)
        loss = loss.sum()  # or loss.mean()
        loss.backward()
        optimizer.step()  # update weights
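For reference, I create criterion and optimizer along these lines (nn.MSELoss is just a stand-in for whatever unreduced criterion is used):

model = Modules(n_models=3, n_hidden=16)
criterion = nn.MSELoss(reduction='none')          # keep per-element losses
optimizer = torch.optim.Adam(model.parameters())  # one optimizer over all modules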
...
During training, I would like each loss value to update only its corresponding module, i.e. to prevent loss[0] from affecting the update of self.nets[1].
Rather than reducing the loss vector to a scalar via sum() or mean(), which ties the optimization of all modules to a single backward pass, I would like to use each individual value loss[i] to update the respective module self.nets[i] and obtain independently trained models.
I have read that backward() on a non-scalar tensor requires a gradient argument, which I am unsure how to construct manually without messing up the computation graph.
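If I understand the docs correctly, this argument is the seed vector of a vector-Jacobian product, and seeding with ones should reproduce the summed case:

# equivalent to loss.sum().backward(), if I read the docs right
loss.backward(gradient=torch.ones_like(loss))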
Also, iterating over the loss vector is probably not efficient, since the computation graph is freed after the first backward() call unless retain_graph=True is passed:
for i in range(loss.shape[0]):
    loss[i].backward(retain_graph=True)
optimizer.step()
The question probably comes down to whether, in PyTorch,
$\frac{\partial\, \texttt{loss.sum()}}{\partial\, \texttt{prediction}[0]} = \frac{\partial\, \texttt{loss}[0]}{\partial\, \texttt{prediction}[0]}$,
i.e. whether backpropagating the summed loss yields the same gradient for each module as backpropagating its individual loss.
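A small sanity check I could run (reusing the hypothetical sizes from above, with MSELoss as a stand-in) would compare the two gradients directly:

import torch
import torch.nn as nn

torch.manual_seed(0)
model = Modules(n_models=2, n_hidden=8)  # hypothetical sizes
data = torch.randn(4, 1)
target = torch.randn(4, 2)

prediction = model(data)
# per-module loss vector, shape [n_models]
loss = nn.MSELoss(reduction='none')(prediction, target).mean(dim=0)

params0 = list(model.nets[0].parameters())
g_sum = torch.autograd.grad(loss.sum(), params0, retain_graph=True)
g_single = torch.autograd.grad(loss[0], params0)
print(all(torch.allclose(a, b) for a, b in zip(g_sum, g_single)))

If the two gradients are identical, as I suspect they are since the modules share no parameters, this should print True.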