Chaining PyTorch modules

I am trying to chain two torch.nn.Module instances in Python and train their parameters jointly. Call them module_a and module_b.

Simplified, my code looks something like this:

import torch
from itertools import chain

device = torch.device("cuda:0")

module_a = ModuleA()
module_b = ModuleB()
module_a.to(device)
module_b.to(device)

# jointly optimize the parameters of both modules
all_params = chain(module_a.parameters(), module_b.parameters())
optimizer = torch.optim.Adam(all_params)

criterion = torch.nn.CrossEntropyLoss(ignore_index=UNCERTAIN)

net_input = get_input_tensor().to(device)
label = get_label().to(device)

for i in range(100000):
    optimizer.zero_grad()

    # feed module_a's output into module_b
    out_a = module_a(net_input)
    out_b = module_b(out_a)
    loss = criterion(out_b, label)
    loss.backward()

    if i % 10 == 0:
        print("loss", loss.item())

    optimizer.step()

I found that the loss decreases initially but then gets stuck at a very high value. My hypothesis is that the gradients are not propagating back to module_a's parameters. When I attach the loss directly to out_a instead, it decreases to a value very close to zero.
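To test that hypothesis, a minimal check (assuming the loop above, run right after loss.backward()) is to print the gradient norms of each module's parameters:

# debugging sketch: inspect gradient magnitudes per module after loss.backward()
for name, p in module_a.named_parameters():
    grad_norm = p.grad.norm().item() if p.grad is not None else None
    print("module_a", name, grad_norm)
for name, p in module_b.named_parameters():
    grad_norm = p.grad.norm().item() if p.grad is not None else None
    print("module_b", name, grad_norm)

If module_a's entries print as None or stay pinned at 0.0 while module_b's do not, the computation graph between the two modules is broken somewhere.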

Is this use case supported by PyTorch?
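For reference, the same chaining can also be expressed by wrapping both modules in a torch.nn.Sequential, which collects the parameters of both for the optimizer. This is only a sketch, assuming module_b consumes module_a's output directly as in the loop above:

# sketch: equivalent chaining via torch.nn.Sequential
combined = torch.nn.Sequential(ModuleA(), ModuleB()).to(device)
optimizer = torch.optim.Adam(combined.parameters())

out = combined(net_input)      # runs ModuleA, then ModuleB on its output
loss = criterion(out, label)
loss.backward()
optimizer.step()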

Update: it turns out I had made a mistake in my implementation of module_b; after fixing it, training works as expected.