Hello,
Let’s say that I have 2 models A & B to train on one task, and I train both models at the same time.
For each epoch (or each dataloader batch), if model A completes training faster than model B, how can I write a statement so that model A can wait for model B and a loss is updated (loss.backward()) based on both models instead of either A or B. Is it necessary to write a wait statement, like sleep(), in the training loop? Thanks!
Best,
Tommy
In PyTorch, you can achieve this by using the built-in torch.utils.data.DataLoader
to handle your data and training both models in the same loop. You don’t need to use a wait statement or sleep()
in this case. Instead, you can compute the loss for both models in the same loop and update the gradients based on the combined loss. Here’s an example of how to do this:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Load your dataset and create DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
# Define your models A and B
model_A = ...
model_B = ...
# Define your loss function and optimizers
criterion = nn.CrossEntropyLoss()
optimizer_A = optim.Adam(model_A.parameters(), lr=0.001)
optimizer_B = optim.Adam(model_B.parameters(), lr=0.001)
# Train the models
num_epochs = 10
for epoch in range(num_epochs):
for i, (inputs, targets) in enumerate(trainloader):
# Zero the parameter gradients
optimizer_A.zero_grad()
optimizer_B.zero_grad()
# Forward pass for both models
outputs_A = model_A(inputs)
outputs_B = model_B(inputs)
# Calculate the loss for both models
loss_A = criterion(outputs_A, targets)
loss_B = criterion(outputs_B, targets)
# Combine the losses
combined_loss = loss_A + loss_B
# Backward pass and optimization
combined_loss.backward()
optimizer_A.step()
optimizer_B.step()
# Print the losses
print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(trainloader)}], Loss A: {loss_A.item():.4f}, Loss B: {loss_B.item():.4f}')
1 Like
Thank you so much @AbdulsalamBande