I’m new to TorchXLA and I decided to try training a model on ImageNet on Kaggle TPU. Following a few tutorials and adapting them to my needs, I came up with this code for the training loop:
def train(device_id, flags):
    """Per-process training loop, spawned once per TPU core by xmp.spawn.

    Trains `MyModel` on the sharded dataset with gradient accumulation,
    gradient clipping, a cosine LR schedule, and per-epoch validation and
    checkpointing. Metrics are reduced across cores with xm.mesh_reduce.

    Args:
        device_id: ordinal passed by xmp.spawn (unused; device comes from xm).
        flags: spawn flags dict (unused here).
    """
    device = xm.xla_device()
    # Split the global batch across cores: each core sees BATCH_SIZE / world_size.
    batch_size = BATCH_SIZE // xm.xrt_world_size()
    steps_per_epoch = DS_SIZE // (batch_size * xm.xrt_world_size())

    # Build datasets
    train_loader, val_loader = get_dataloaders(steps_per_epoch, batch_size, device)
    xm.rendezvous("loaded dataset")

    model = MyModel()
    model = model.to(device)
    # Sync initial weights from the master core so every replica starts identical.
    xm.broadcast_master_param(model)
    model_params = sum(p.numel() for p in model.parameters())
    xm.master_print(f'Model parameters: {model_params:,d}')

    # Set up the optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=LR,
        weight_decay=WD
    )
    scheduler = cosine_scheduler_with_warmup(
        optimizer,
        total_epochs=EPOCHS,
        # The scheduler steps once per *optimizer* step, not per micro-batch.
        steps_per_epoch=math.ceil(steps_per_epoch / GRAD_ACC_STEPS),
        warmup_epochs=WARMUP_EPOCHS,
        initial_lr=0.01,
        end_lr=0.001
    )
    xm.rendezvous("loaded model and optimizer")

    start_epoch, train_history, test_history = SERIAL_EXEC.run(
        lambda: checkpoint_load(model, optimizer, scheduler))
    xm.rendezvous("loaded weights")
    xm.master_print("training begins")

    for epoch in range(start_epoch, start_epoch + EPOCHS):
        xm.master_print(f"Starting epoch {epoch + 1}")
        model.train()
        total_loss = torch.zeros((), device=device)
        local_total_batches = 0

        # BUGFIX: zero gradients only once before the accumulation window starts.
        # The original called optimizer.zero_grad() on every micro-step, which
        # discarded the gradients it was supposed to accumulate.
        optimizer.zero_grad()
        for step, (images, labels, _) in zip(range(1, steps_per_epoch + 1), train_loader):
            images = images.to(device)
            labels = labels.to(device)
            # Get predictions
            outputs = model(images)
            # Compute loss
            loss = loss_fn(outputs, labels)
            # BUGFIX: scale by GRAD_ACC_STEPS so the accumulated gradient is an
            # average over micro-batches instead of a sum (which would silently
            # multiply the effective learning rate by GRAD_ACC_STEPS).
            (loss / GRAD_ACC_STEPS).backward()
            # Do optimizer step with gradient clipping at the end of each
            # accumulation window (and at the final, possibly partial, window).
            if step % GRAD_ACC_STEPS == 0 or step == steps_per_epoch:
                xm.reduce_gradients(optimizer, pin_layout=True)
                torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                optimizer.step()
                scheduler.step()
                # Reset for the next accumulation window.
                optimizer.zero_grad()
                xm.mark_step()
            with torch.no_grad():
                # Track the *unscaled* loss for reporting.
                total_loss += loss.detach() * images.size(0)
                local_total_batches += images.size(0)

        # Aggregate metrics across devices (.item() forces a host sync here).
        global_loss = xm.mesh_reduce("total_loss", total_loss.item(), sum)
        global_batches = xm.mesh_reduce("total_batches", local_total_batches, sum)
        average_loss = global_loss / global_batches
        xm.master_print(f"Epoch [{epoch+1}/{EPOCHS}], Training Loss: {average_loss:.4f}")

        # Evaluation loop
        val_loss, val_accuracy = eval_on_val(val_loader, model, device)
        xm.master_print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}%")

        train_history.append([average_loss])
        test_history.append([val_loss, val_accuracy])
        # xm.save gathers to the master and writes from a single process.
        xm.save({
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'history': [train_history, test_history],
            'epoch': epoch + 1
        }, 'checkpoint.pth')
    xm.master_print("Training complete")
And here is the evaluation function:
def eval_on_val(val_loader, model, device):
    """Run one pass over the validation loader and return globally-reduced
    (average loss, accuracy-percentage) across all TPU cores."""
    model.eval()
    loss_sum = torch.zeros((), device=device)
    n_correct = torch.zeros((), device=device)
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)
            batch_size = batch_labels.size(0)
            # Weight the per-batch mean loss by batch size so the final
            # division yields a true per-sample average.
            loss_sum += batch_loss * batch_size
            _, preds = torch.max(logits.data, 1)
            n_correct += (preds == batch_labels).float().sum()
            n_seen += batch_size
    # Sum the per-core statistics across every device before averaging,
    # so the result is identical on all cores.
    loss_all = xm.mesh_reduce("val_loss", loss_sum.item(), sum)
    correct_all = xm.mesh_reduce("val_correct", n_correct.item(), sum)
    seen_all = xm.mesh_reduce("val_total", n_seen, sum)
    return loss_all / seen_all, 100.0 * correct_all / seen_all
I have a few concerns with this code:
- Will the accuracy and loss be aggregated correctly?
- Will there be an issue with the gradient accumulations?
- Will the model need to be recompiled every epoch because of the eval function?
- Did I use `xm.broadcast_master_param()` correctly?