I am trying to chain two torch.nn.Modules in Python and then jointly train their parameters. Suppose they are module_a and module_b.
Simplified, my code looks something like this:
from itertools import chain

import torch

device = torch.device("cuda:0")

module_a = ModuleA()
module_b = ModuleB()
module_a.to(device)
module_b.to(device)

# combine the parameters of both modules into a single optimizer
all_params = chain(module_a.parameters(), module_b.parameters())
optimizer = torch.optim.Adam(all_params)
criterion = torch.nn.CrossEntropyLoss(ignore_index=UNCERTAIN)

net_input = get_input_tensor().to(device)
label = get_label().to(device)

for i in range(100000):
    optimizer.zero_grad()
    out_a = module_a(net_input)   # first module
    out_b = module_b(out_a)       # feed its output into the second module
    loss = criterion(out_b, label)
    loss.backward()
    if i % 10 == 0:
        print("loss", loss.item())
    optimizer.step()
I found that the loss goes down initially, but then gets stuck at a very high value. My hypothesis is that the backprop is not going through to module_a's parameters. When I hook the loss up directly to out_a, the loss decreases to a value very close to zero.
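For reference, this is roughly how I would check whether any gradient actually reaches module_a's parameters after backward (a minimal sketch, assuming the same module_a, module_b, criterion, net_input and label as above):

# run one forward/backward pass, then inspect the gradients of module_a
out_a = module_a(net_input)
out_b = module_b(out_a)
loss = criterion(out_b, label)
loss.backward()

for name, p in module_a.named_parameters():
    # p.grad is None (or all zeros) if no gradient flowed back into this parameter
    grad_norm = None if p.grad is None else p.grad.norm().item()
    print(name, grad_norm)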
Is this use case supported by PyTorch?