Cannot replicate if number of devices (1) is different from 8 - TPU

Hello all,
I am just getting started with TPUs in a Kaggle notebook, and I am running into the following error:

Exception in device=TPU:7: Cannot replicate if number of devices (1) is different from 8

where device=TPU:i for i going from 0 to 7.
My trainer loop, run function and spawn function are as follows:

import os
import gc

import torch
from torch.utils.data import DataLoader

import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

def lossReduction(vals):
    return sum(vals)/len(vals)
    
def xlaTrainer(dataLoader, model, optimizer, scheduler=None, device=None, min_loss = 1e9):
    model.train()
    for batch_id, data in enumerate(dataLoader):
        
        image = data['image'].to(device)
        mask = data['ground_truth'].to(device)
        
        optimizer.zero_grad()
        loss = model((image, mask))
        
        if batch_id % 50 == 0:
            loss_reduced = xm.mesh_reduce('loss_reduce', loss, lossReduction)
            xm.master_print(f'Loss = {loss_reduced}')
            
        loss.backward()
        xm.optimizer_step(optimizer)
        if scheduler is not None:
            scheduler.step()
        
            

def run():
    device = xm.xla_device()
    batch_size=32
    epochs = 100
    min_loss = 1e9
    common_path = '/kaggle/input/dut-omron/DUT OMROM/'
    images = 'DUT-OMRON-image/DUT-OMRON-image/'
    masks = 'DUT-OMRON-gt-pixelwise.zip/pixelwiseGT-new-PNG/'

    image_path = common_path + images
    map_path = common_path + masks

    model = SegaNet(backbone='resnet50', num_classes=1).train().to(device)
    train_dataset = DUT_OMRON(image_path=image_path, image_map_path=map_path)
    # Shard the dataset across the TPU cores, one replica per process
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=0)

    parameters = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in parameters if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in parameters if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    lr = 3e-4 * xm.xrt_world_size()
    num_train_steps = int(len(train_dataset)/batch_size/xm.xrt_world_size() * epochs)

    optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr)
    
    for epoch in range(epochs):
        gc.collect()
        parallel_loader = pl.ParallelLoader(train_dataloader, [device])
        gc.collect()
        
        xlaTrainer(parallel_loader.per_device_loader(device), model, optimizer, device=device)
        del parallel_loader
        gc.collect()
        xm.master_print("Here")
        xm.save(model.state_dict(), "dut_omron.bin")
        
        
        

def spawn_function(rank, flags):
    a = run()
        
os.environ['XLA_USE_BF16']="1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '10000000000'
FLAGS = {}
xmp.spawn(spawn_function, args=(), nprocs=8, start_method='fork')

The TPU available on Kaggle is a TPU v3-8.
Where am I going wrong?
TIA

I tried to run your code on a Google Cloud VM but I did not see the error you mentioned (I had to remove the flags argument in your spawn_function since you pass args=()). This kind of problem is usually caused by your code calling XLA functions at global scope together with multiprocessing, which is not allowed. I am guessing you did that prior to this code and the machine was left in a bad state? Restarting the Kaggle kernel might help.
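
For illustration, a minimal sketch of the pattern being described (my own sketch, not your notebook): every XLA call such as xm.xla_device() stays inside the function handed to xmp.spawn, nothing XLA-related runs at global scope, and FLAGS is forwarded through args so the flags parameter actually receives a value. Model and data setup are omitted; only the process structure is shown.

import os

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def spawn_function(rank, flags):
    # The XLA device is acquired inside the child process only,
    # never at module/global scope in the notebook.
    device = xm.xla_device()
    xm.master_print(f'process {rank} running on {device} with flags={flags}')

if __name__ == '__main__':
    os.environ['XLA_USE_BF16'] = '1'
    FLAGS = {}
    # args=(FLAGS,) is what fills in the second parameter of spawn_function;
    # xmp.spawn calls it as spawn_function(index, *args).
    xmp.spawn(spawn_function, args=(FLAGS,), nprocs=8, start_method='fork')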

Issues · pytorch/xla · GitHub is a better place to ask PyTorch/XLA related questions. If you can share the Kaggle notebook we might be able to help you see what is going wrong.