How can I make model A wait for model B to finish before calling loss.backward() in each epoch?


Let’s say that I have 2 models A & B to train on one task, and I train both models at the same time.

For each epoch (or each dataloader batch), if model A completes training faster than model B, how can I write a statement so that model A waits for model B and the loss update (loss.backward()) is based on both models instead of only A or B? Is it necessary to write a wait statement, like sleep(), in the training loop? Thanks!


In PyTorch, you can achieve this by using the built-in torch.utils.data.DataLoader to handle your data and training both models in the same loop. You don’t need a wait statement or sleep() in this case: the statements in the loop body run one after another, so both forward passes have finished by the time you call backward(). You can compute the loss for both models in the same loop and update the gradients based on the combined loss. Here’s an example of how to do this:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Load your dataset and create DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

# Define your models A and B
model_A = ...
model_B = ...

# Define your loss function and optimizers
criterion = nn.CrossEntropyLoss()
optimizer_A = optim.Adam(model_A.parameters(), lr=0.001)
optimizer_B = optim.Adam(model_B.parameters(), lr=0.001)

# Train the models
num_epochs = 10
for epoch in range(num_epochs):
    for i, (inputs, targets) in enumerate(trainloader):
        # Zero the parameter gradients of both models
        optimizer_A.zero_grad()
        optimizer_B.zero_grad()

        # Forward pass for both models
        outputs_A = model_A(inputs)
        outputs_B = model_B(inputs)

        # Calculate the loss for both models
        loss_A = criterion(outputs_A, targets)
        loss_B = criterion(outputs_B, targets)

        # Combine the losses
        combined_loss = loss_A + loss_B

        # Backward pass on the combined loss, then update both models
        combined_loss.backward()
        optimizer_A.step()
        optimizer_B.step()

        # Print the losses
        print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(trainloader)}], Loss A: {loss_A.item():.4f}, Loss B: {loss_B.item():.4f}')
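
If you want to convince yourself that a single backward() on the combined loss really updates both models, here is a minimal sketch you can run without downloading MNIST. The two small models, the synthetic batch, and its shapes (32 samples, 8 features, 3 classes) are placeholders I made up for illustration, not something from your setup.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-ins for model A and model B (assumed for illustration)
model_A = nn.Linear(8, 3)
model_B = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))

criterion = nn.CrossEntropyLoss()
optimizer_A = optim.Adam(model_A.parameters(), lr=0.001)
optimizer_B = optim.Adam(model_B.parameters(), lr=0.001)

# Synthetic batch: 32 samples, 8 features, 3 classes (assumed shapes)
inputs = torch.randn(32, 8)
targets = torch.randint(0, 3, (32,))

# Snapshot one weight tensor per model before the update
before_A = model_A.weight.detach().clone()
before_B = model_B[0].weight.detach().clone()

optimizer_A.zero_grad()
optimizer_B.zero_grad()

# The two forward passes run one after the other; no waiting logic is needed
loss_A = criterion(model_A(inputs), targets)
loss_B = criterion(model_B(inputs), targets)
combined_loss = loss_A + loss_B

# One backward call populates .grad for the parameters of both models
combined_loss.backward()
grads_A_ok = all(p.grad is not None for p in model_A.parameters())
grads_B_ok = all(p.grad is not None for p in model_B.parameters())

optimizer_A.step()
optimizer_B.step()

# Check that the optimizer steps actually changed both models
changed_A = not torch.equal(before_A, model_A.weight)
changed_B = not torch.equal(before_B, model_B[0].weight)
print(grads_A_ok, grads_B_ok, changed_A, changed_B)
```

One thing to keep in mind: as long as the two models share no parameters, loss_A only depends on model A and loss_B only on model B, so backward() on the sum gives each model the same gradients it would get from its own loss alone. Summing is mostly a convenience here; the models only influence each other if you add a term that ties their outputs or weights together.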

Thank you so much @AbdulsalamBande