I would like to create a model with Pytorch that has several identical models not connected by autograd, but would train in parallel. However, I’m running into issues with performing the backward pass on the current way I’ve structured it. How would you structure this to run concurrent models in a single model? I.e. the inputs flow through separate parallel paths and should have their loss and backward prop be kept separate.
import torch
import torch.nn as nn
class DummyModel(nn.Module):
def __init__(self, in_size, hidden_size, out_size, num_advisors):
super().__init__()
self.n1 = in_size
self.h1 = hidden_size
self.n2 = out_size
self.num_advisors = num_advisors
self.modlist = nn.ModuleList()
for _ in range(self.num_advisors):
self.modlist.append(
nn.Sequential(nn.Linear(self.n1, self.h1), nn.Linear(self.h1, self.h1), nn.Linear(self.h1, self.n2)))
self.sigm = nn.Sigmoid()
def forward(self, x):
return self.sigm(torch.stack([_(x) for _ in self.modlist], dim=1))
in_size = 10
hidden_size = 64
out_size = 1
num_advisors = 20
batch_size = 100
model = DummyModel(in_size, hidden_size, out_size, num_advisors)
loss_func = nn.BCELoss()
dummy_inputs = torch.rand((batch_size, in_size))
dummy_targets = torch.rand((batch_size, num_advisors, 1))
dummy_outputs= model(dummy_inputs)
loss=torch.stack([loss_func(dummy_outputs[:,x,:], dummy_targets[:,x,:]) for x in range(dummy_outputs.size()[1])])
loss.backward()
The error is:
Traceback (most recent call last):
File "scratch.py", line 38, in <module>
loss.backward()
File "pythonProject1\lib\site-packages\torch\_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "pythonProject1\lib\site-packages\torch\autograd\__init__.py", line 166, in backward
grad_tensors_ = _make_grads(tensors, grad_tensors_, is_grads_batched=False)
File "pythonProject1\lib\site-packages\torch\autograd\__init__.py", line 67, in _make_grads
raise RuntimeError("grad can be implicitly created only for scalar outputs")
RuntimeError: grad can be implicitly created only for scalar outputs