PyTorch TPU printing 1 epoch's performance 8 times

Is the code attached below bug-free?

def train_model():
    global train_dataset, valid_dataset
    
    torch.manual_seed(42)
    
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=train_sampler,
        num_workers=0,
        drop_last=True) # print(len(train_loader))
    
    '''valid_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        )'''
        
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        #sampler=valid_sampler,
        shuffle=False,
        num_workers=0,
        drop_last=True)
    
    #xm.master_print(f"Train for {len(train_loader)} steps per epoch")
    LOGGER.debug(f"Train for {len(train_loader)} steps per epoch")
    # Scale learning rate to num cores
    lr  = 0.0001 * xm.xrt_world_size()

    # Get loss function, optimizer, and model
    device = xm.xla_device()

    #model = model()
    '''
    for param in model.base_model.parameters(): # freeze some layers
        param.requires_grad = False'''
    
    
    global model
    
    model = model.to(device)

    criterion = torch.nn.BCEWithLogitsLoss() #  MSELoss
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
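    # OneCycleLR: the second positional argument is max_lr, and the schedule
    # length is epochs * steps_per_epoch batches.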
    scheduler = OneCycleLR(optimizer, 
                           lr, 
                           div_factor=10.0, 
                           final_div_factor=50.0, 
                           epochs=NUM_EPOCH,
                           steps_per_epoch=len(train_loader))

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        
        #xm.master_print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        LOGGER.debug('Epoch {}/{}'.format(epoch, num_epochs - 1))
        #xm.master_print('-' * 10)
        LOGGER.debug('-' * 10)
        scheduler.step()
        
        running_loss = 0.0
        tk0 = tqdm(loader, total=int(len(train_loader)))
        counter = 0
        for bi, d in enumerate(tk0):
            inputs = d["image"]
            labels = d["label"].view(-1, 1)
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)
            optimizer.zero_grad()
            #with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            #loss = criterion(outputs, torch.max(labels, 1)[1])
            loss.backward()
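            # xm.optimizer_step() all-reduces the gradients across the TPU cores before applying the update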
            xm.optimizer_step(optimizer)
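            # loss.item() is the mean loss over the batch, so this accumulates a per-sample total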
            running_loss += loss.item() * inputs.size(0)
            #print(running_loss)
            counter += 1
            tk0.set_postfix(loss=(running_loss / (counter * BATCH_SIZE)))
        epoch_loss = running_loss / len(train_loader)
        #xm.master_print('Training Loss: {:.8f}'.format(epoch_loss))
        LOGGER.debug('Training Loss: {:.8f}'.format(epoch_loss))

    def test_loop_fn(loader):
        tk0 = tqdm(loader, total=int(len(valid_loader)))
        counter = 0
        total_samples, correct = 0, 0
        for bi, d in enumerate(tk0):
            inputs = d["image"]
            labels = d["label"].view(-1, 1)
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)
            optimizer.zero_grad()
            
            with torch.no_grad():
                
                output = model(inputs)
                
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(labels.view_as(pred)).sum().item()
                total_samples += inputs.size()[0]
        accuracy = 100.0 * correct / total_samples
        #print('[xla:{}] Accuracy={:.4f}%'.format(xm.get_ordinal(), accuracy), flush=True)
        model.train()
        return accuracy

    # Train - valid  loop
    accuracy = []
    for epoch in range(1, num_epochs + 1):
        start = time.time()
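        # ParallelLoader preloads batches onto the XLA device in background threads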
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))
        
        para_loader = pl.ParallelLoader(valid_loader, [device])
        accuracy.append(test_loop_fn(para_loader.per_device_loader(device)))
        #xm.master_print("Finished training epoch {}  Val-Acc {:.4f} in {:.4f} sec".format(epoch, accuracy[-1],   time.time() - start))        
        
        LOGGER.debug("Finished training epoch {}  Val-Acc {:.4f} in {:.4f} sec".format(epoch, accuracy[-1],   time.time() - start))   
        valauc = accuracy[-1]
        if(epoch>4):
            xm.save(model.state_dict(), f"./epoch{epoch}valauc{valauc}.bin")
    return accuracy

def _mp_fn(rank, flags):
    global acc_list
    torch.set_default_tensor_type('torch.FloatTensor')
    res = train_model()

FLAGS={}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')
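
Side note: I think the per-epoch lines in the log below appear 8 times because all 8 spawned processes run LOGGER.debug, but I am not sure. A rough, untested sketch of what I have in mind to restrict those messages to one process, using a hypothetical helper log_master() built on xm.is_master_ordinal() (instead of the xm.master_print calls I commented out above):

def log_master(msg):
    # hypothetical helper: only the master ordinal (one of the 8 processes) logs
    if xm.is_master_ordinal():
        LOGGER.debug(msg)

# e.g. inside train_model() / train_loop_fn():
# log_master(f"Train for {len(train_loader)} steps per epoch")
# log_master('Epoch {}/{}'.format(epoch, num_epochs - 1))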

The first epoch's training log looks like this:

2020-05-09 12:21:29,371 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:29,710 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:29,721 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:29,911 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:30,561 DEBUG Epoch 1/6
2020-05-09 12:21:30,564 DEBUG ----------
2020-05-09 12:21:31,065 DEBUG Epoch 1/6
2020-05-09 12:21:31,076 DEBUG ----------
2020-05-09 12:21:31,120 DEBUG Epoch 1/6
2020-05-09 12:21:31,130 DEBUG ----------
2020-05-09 12:21:31,390 DEBUG Epoch 1/6
2020-05-09 12:21:31,426 DEBUG ----------
2020-05-09 12:21:32,629 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:33,573 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:33,748 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:33,883 DEBUG Train for 445 steps per epoch
2020-05-09 12:21:34,889 DEBUG Epoch 1/6
2020-05-09 12:21:34,914 DEBUG ----------
2020-05-09 12:21:35,573 DEBUG Epoch 1/6
2020-05-09 12:21:35,613 DEBUG ----------
2020-05-09 12:21:35,823 DEBUG Epoch 1/6
2020-05-09 12:21:35,845 DEBUG ----------
2020-05-09 12:21:36,128 DEBUG Epoch 1/6
2020-05-09 12:21:36,171 DEBUG ----------
2020-05-09 12:35:08,162 DEBUG Training Loss: 11.22450873
2020-05-09 12:35:08,172 DEBUG Training Loss: 11.19612112
2020-05-09 12:35:08,309 DEBUG Training Loss: 11.18398799
2020-05-09 12:35:08,352 DEBUG Training Loss: 11.16665337
2020-05-09 12:35:08,362 DEBUG Training Loss: 11.20103131
2020-05-09 12:35:08,357 DEBUG Training Loss: 11.19919075
2020-05-09 12:35:08,368 DEBUG Training Loss: 11.19310062
2020-05-09 12:35:08,386 DEBUG Training Loss: 11.21970569
2020-05-09 12:39:31,562 DEBUG Finished training epoch 1 Val-Acc 50.5348 in 1080.4523 sec

The validation accuracy calculation is slow; somehow validation seems to use only 1 core, while the training phase uses all 8 cores. How do I solve this and make the validation calculation fast? Also, I see the same validation accuracy in every epoch, so maybe I have a bug in my code? Another thing: if I train this model for 8-10 epochs, the Kaggle kernel doesn't finish the commit and gives an error that isn't visible, so maybe somewhere in my code I am requesting too much memory and getting an OOM because of that? Also, if I try sampler=valid_sampler for valid_loader, I get an error. Please help me find the bugs in my code, thank you a lot in advance.
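
For reference, here is a rough, untested sketch of how I imagine the validation side could be sharded across the 8 cores, with the per-core accuracy combined via xm.mesh_reduce. It reuses valid_dataset, BATCH_SIZE, model, device and the xm/pl imports from above, builds the sampler over valid_dataset (not train_dataset), and assumes the model outputs a single logit since I train with BCEWithLogitsLoss. Is this roughly the right direction?

valid_sampler = torch.utils.data.distributed.DistributedSampler(
    valid_dataset,                       # valid_dataset here, not train_dataset
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=False)

valid_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    sampler=valid_sampler,
    num_workers=0,
    drop_last=True)

def test_loop_fn(loader):
    model.eval()
    total_samples, correct = 0, 0
    with torch.no_grad():
        for d in loader:
            inputs = d["image"].to(device, dtype=torch.float)
            labels = d["label"].view(-1, 1).to(device, dtype=torch.float)
            output = model(inputs)                        # shape [batch, 1], a single logit
            pred = (torch.sigmoid(output) > 0.5).float()  # threshold instead of output.max(1)
            correct += pred.eq(labels).sum().item()
            total_samples += inputs.size(0)
    accuracy = 100.0 * correct / total_samples
    # average the per-core accuracies so every process sees the same number
    accuracy = xm.mesh_reduce('valid_accuracy', accuracy, lambda vals: sum(vals) / len(vals))
    model.train()
    return accuracy

# called as before in the epoch loop:
# para_loader = pl.ParallelLoader(valid_loader, [device])
# accuracy.append(test_loop_fn(para_loader.per_device_loader(device)))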

FYI this conversation has been moved to https://github.com/pytorch/xla/issues/2054.
