Hello all,
I am just getting started with TPU on a Kaggle notebook, and running into the following error - which mainly says as -
Exception in device=TPU:7: Cannot replicate if number of devices (1) is different from 8
where device=TPU:i
where I goes from 0 to 7.
My trainer loop, run and the spawn function are as follows -
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
def lossReduction(vals):
return sum(vals)/len(vals)
def xlaTrainer(dataLoader, model, optimizer, scheduler=None, device=None, min_loss = 1e9):
model.train()
for batch_id, data in enumerate(dataLoader):
image = data['image'].to(device)
mask = data['ground_truth'].to(device)
optimizer.zero_grad()
loss = model((image, mask))
if batch_id % 50 == 0:
loss_reduces = xm.mesh_reduce('loss_reduce', loss, lossReduction)
xm.master_print(f'Loss = {loss_reduced}')
loss.backward()
xm.optimizer_step(optimizer)
if scheduler is not None:
scheduler.step()
def run():
device = xm.xla_device()
batch_size=32
epochs = 100
min_loss = 1e9
common_path = '/kaggle/input/dut-omron/DUT OMROM/'
images = 'DUT-OMRON-image/DUT-OMRON-image/'
masks = 'DUT-OMRON-gt-pixelwise.zip/pixelwiseGT-new-PNG/'
image_path = common_path + images
map_path = common_path + masks
model = SegaNet(backbone='resnet50', num_classes=1).train().to(device)
train_dataset = DUT_OMRON(image_path=images, image_map_path=map_path)
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=xm.xrt.world_size(), rank=xm.get_ordinal, shuffle=True)
train_dataloader = Dataloader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=0)
parameters = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
lr = 3e-4 * xm.xrt_world_size()
num_train_steps = int(len(train_datasets)/batch_size/xm.xrt_world_size() * epochs)
optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=lr)
for epoch in range(epochs):
gc.collect()
parallel_loader = pl.ParallelLoader(train_dataloader, [device])
gc.collect()
xlaTrainer(parallel_loader.per_device_loader(device), model, optimizer, device=device)
del parallel_loader
gc.collect()
xm.master_print("Here")
xm.save(model.state_dict(), "dut_omron.bin")
def spawn_function(rank, flags):
a = run()
os.environ['XLA_USE_BF16']="1"
os.environ['XLA_TENSOR_ALLOCATOR_MAXSIZE'] = '10000000000'
FLAGS = {}
xmp.spawn(spawn_function, args=(), nprocs=8, start_method='fork')
The TPU version available on kaggle is TPUv3-8.
Where am I going wrong ?
TIA