Cannot replicate if number of devices (1) is different from 8 - TPU

Hello all,
I am just getting started with TPU on a Kaggle notebook, and running into the following error - which mainly says as -

Exception in device=TPU:7: Cannot replicate if number of devices (1) is different from 8

where device=TPU:i where I goes from 0 to 7.
My trainer loop, run and the spawn function are as follows -

import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils

def lossReduction(vals):
    return sum(vals)/len(vals)
def xlaTrainer(dataLoader, model, optimizer, scheduler=None, device=None, min_loss = 1e9):
    for batch_id, data in enumerate(dataLoader):
        image = data['image'].to(device)
        mask = data['ground_truth'].to(device)
        loss = model((image, mask))
        if batch_id % 50 == 0:
            loss_reduces = xm.mesh_reduce('loss_reduce', loss, lossReduction)
            xm.master_print(f'Loss = {loss_reduced}')
        if scheduler is not None:

def run():
    device = xm.xla_device()
    epochs = 100
    min_loss = 1e9
    common_path = '/kaggle/input/dut-omron/DUT OMROM/'
    images = 'DUT-OMRON-image/DUT-OMRON-image/'
    masks = ''

    image_path = common_path + images
    map_path = common_path + masks

    model = SegaNet(backbone='resnet50', num_classes=1).train().to(device)
    train_dataset = DUT_OMRON(image_path=images, image_map_path=map_path)
    sampler =, num_replicas=xm.xrt.world_size(), rank=xm.get_ordinal, shuffle=True)
    train_dataloader = Dataloader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=0)

    parameters = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    lr = 3e-4 * xm.xrt_world_size()
    num_train_steps = int(len(train_datasets)/batch_size/xm.xrt_world_size() * epochs)

    optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr)
    for epoch in range(epochs):
        parallel_loader = pl.ParallelLoader(train_dataloader, [device])
        xlaTrainer(parallel_loader.per_device_loader(device), model, optimizer, device=device)
        del parallel_loader
        xm.master_print("Here"), "dut_omron.bin")

def spawn_function(rank, flags):
    a = run()
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '10000000000'
FLAGS = {}
xmp.spawn(spawn_function, args=(), nprocs=8, start_method='fork')

The TPU version available on kaggle is TPUv3-8.
Where am I going wrong ?