Training multiple models at the same time

Is it possible to train multiple models simultaneously?
For instance, suppose my network class is Net.

net1 = Net()
net2 = Net()

Is it possible to train net1 and net2 simultaneously?

Thanks.

Hi,

Yes, of course. Did you run into any issues doing this?

I don’t know how to do this exactly. Is there any specific package I need to use?

Thanks.

I think you want to be more precise about what you mean by “simultaneously”.
You can have one optimizer for each model and just train them in one training loop, either on the same data or not.
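For concreteness, here is a rough sketch of what that could look like, assuming the Net class from the question plus a device and a trainloader; the SGD optimizer, learning rate, and cross-entropy loss are just placeholders:

import torch
import torch.nn.functional as F

net1, net2 = Net().to(device), Net().to(device)
optim1 = torch.optim.SGD(net1.parameters(), lr=0.01)
optim2 = torch.optim.SGD(net2.parameters(), lr=0.01)

for data, label in trainloader:
    data, label = data.to(device), label.to(device)

    # Each model gets its own forward/backward/step, so the graphs stay independent.
    loss1 = F.cross_entropy(net1(data), label)
    optim1.zero_grad()
    loss1.backward()
    optim1.step()

    loss2 = F.cross_entropy(net2(data), label)
    optim2.zero_grad()
    loss2.backward()
    optim2.step()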

Hey @albanD, I’m trying to train two models in one training loop on the same data. I have a separate optimizer for each one. However, I get this error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Aren’t two separate computation graphs created for the models?


Hi,

If you do two independent forward calls, yes. But if you reuse some computations from one in the other, no.
You can pass retain_graph=True to the first backward call to make sure the second will succeed.
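For example, if the two losses do end up sharing part of a graph (loss1 and loss2 here are just placeholders for whatever you compute), the pattern would be roughly:

loss1.backward(retain_graph=True)  # keep the shared graph buffers alive
loss2.backward()                   # can still backprop through the shared part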

What I’m trying to do in the training loop is use the output of one model as the ground truth label of the other.
Training loop:

for e in range(epochs):

  epoch_loss_t = 0
  epoch_loss_s = 0

  for i, (data, label) in enumerate(trainloader):

    data, label = data.to(device), label.to(device)

    t_out = teacher(data)
    t_loss = F.cross_entropy(t_out, label)
    t_loss.backward()
    t_optim.step()
    t_optim.zero_grad()
    epoch_loss_t += t_loss.item()

    s_out = student(data)
    s_loss = (1 - alpha) * F.cross_entropy(s_out, label)
    s_loss += alpha * loss_fn(s_out, t_out)
    s_loss.backward()
    s_optim.step()
    s_optim.zero_grad()
    epoch_loss_s += s_loss.item()

Is there a way I can somehow detach the output of the first model so that a separate computation graph is created for the second one? I’m guessing that simply calling .detach() on the output of the first model should work.

Yes, doing loss_fn(s_out, t_out.detach()) will break the link between the two 🙂
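With that change, the student update from the loop above would look like this (same names as your code, shown only as a sketch):

    s_out = student(data)
    s_loss = (1 - alpha) * F.cross_entropy(s_out, label)
    # Detach the teacher output so the student loss does not try to backprop
    # through the teacher graph that was already freed by t_loss.backward().
    s_loss += alpha * loss_fn(s_out, t_out.detach())
    s_loss.backward()
    s_optim.step()
    s_optim.zero_grad()
    epoch_loss_s += s_loss.item()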
