Is this the correct way to do ImageNet training on Torch XLA?

I’m new to TorchXLA and I decided to try training a model on ImageNet on Kaggle TPU. Following a few tutorials and adapting them to my needs, I came up with this code for the training loop:

def train(device_id, flags):
        device = xm.xla_device()
        rank = xm.get_local_ordinal()

        batch_size = BATCH_SIZE // xm.xrt_world_size()
        steps_per_epoch = DS_SIZE // (batch_size * xm.xrt_world_size())
        # Build datasets
        train_loader, val_loader = get_dataloaders(steps_per_epoch, batch_size, device)
        xm.rendezvous("loaded dataset")
    
        model = MyModel()
        model = model.to(device)
        xm.broadcast_master_param(model)
        model_params = sum(p.numel() for p in model.parameters())

        xm.master_print(f'Model parameters: {model_params:,d}')
    
        # Set up the optimizer
        optimizer = optim.AdamW(
            model.parameters(),
            lr=LR,
            weight_decay=WD
        )
    
        scheduler = cosine_scheduler_with_warmup(
            optimizer,
            total_epochs=EPOCHS,
            steps_per_epoch=math.ceil(steps_per_epoch / GRAD_ACC_STEPS),
            warmup_epochs=WARMUP_EPOCHS,
            initial_lr=0.01,
            end_lr=0.001
        )
    
        xm.rendezvous("loaded model and optimizer")

        start_epoch, train_history, test_history = SERIAL_EXEC.run(lambda: checkpoint_load(model, optimizer, scheduler))
        
        xm.rendezvous("loaded weights")
        xm.master_print("training begins")
    
        for epoch in range(start_epoch, start_epoch + EPOCHS):
            xm.master_print(f"Starting epoch {epoch + 1}")
            model.train()
            total_loss = torch.zeros((), device=device)
            local_total_batches = 0
            for step, (images, labels, _) in zip(range(1, steps_per_epoch+1), train_loader):
                images = images.to(device)
                labels = labels.to(device)
                
                # Get predictions
                outputs = model(images)

                # Compute distillation loss
                loss = loss_fn(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                # Do optimizer step with gradient clipping
                if step % GRAD_ACC_STEPS == 0 or step == steps_per_epoch:
                    xm.reduce_gradients(optimizer, pin_layout=True)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                    optimizer.step()
                    scheduler.step()
                xm.mark_step()
                with torch.no_grad():
                    total_loss += loss * images.size(0)
                    local_total_batches += images.size(0)
    
            # Aggregate metrics across devices
            global_loss = xm.mesh_reduce("total_loss", total_loss.item(), sum)
            global_batches = xm.mesh_reduce("total_batches", local_total_batches, sum)
    
            average_loss = global_loss / global_batches
    
            xm.master_print(f"Epoch [{epoch+1}/{EPOCHS}], Training Loss: {average_loss:.4f}")
    
            # Evaluation loop
            val_loss, val_accuracy = eval_on_val(val_loader, model, device)
            xm.master_print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}%")
    
            train_history.append([average_loss])
            test_history.append([val_loss, val_accuracy])
    
            xm.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'history': [train_history, test_history],
                'epoch': epoch + 1
            }, 'checkpoint.pth')
    
        xm.master_print("Training complete")

And the eval function:

def eval_on_val(val_loader, model, device):
    model.eval()
    val_loss = torch.zeros((), device=device)
    correct = torch.zeros((), device=device)
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss * labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).float().sum()
            total += labels.size(0)

    global_val_loss = xm.mesh_reduce("val_loss", val_loss.item(), sum)
    global_correct = xm.mesh_reduce("val_correct", correct.item(), sum)
    global_total = xm.mesh_reduce("val_total", total, sum)

    avg_val_loss = global_val_loss / global_total
    val_accuracy = 100.0 * global_correct / global_total

    return avg_val_loss, val_accuracy

I have a few concerns with this code:

  • Will the accuracy and loss be aggregated correctly?
  • Will there be an issue with the gradient accumulations?
  • Will the model need to be recompiled every epoch because of the eval function?
  • Did I use xm.broadcast_master_param() correctly?