Error in xm.optimizer_step()

Hi!
I am fairly new to PyTorch/XLA. While getting familiar with it, I have been testing basic code that trains a model on CIFAR-10 on my Cloud TPU v4-8 with the tpu-vm-v4-pt-1.13 runtime. I keep getting an error at the optimizer step, and I have not been able to find a solution online.
torch : 1.13.0+cu117
torch_xla: 1.13
device: "xla:0"
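
In case it helps, this is a small snippet I can run to print what the runtime reports for this process (these are existing xla_model helpers; I can post its output if that is useful):

import os
os.environ['PJRT_DEVICE'] = 'TPU'
import torch_xla.core.xla_model as xm

# What the runtime reports for this single process
print('xla device       :', xm.xla_device())
print('supported devices:', xm.get_xla_supported_devices())
print('world size       :', xm.xrt_world_size())
print('ordinal          :', xm.get_ordinal())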

Please look over the code and the error below.

Code:

# Import libraries
import os
os.environ['PJRT_DEVICE']='TPU'
import torch
from torch.utils.data import DataLoader
import torch_xla.core.xla_model as xm
from torchvision import datasets, transforms

device = xm.xla_device()


# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load CIFAR-10 dataset
train_data = datasets.CIFAR10('~/.torch/datasets', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

# Define model class (adapted from an MNIST example, hence the class name)
class MNISTModel(torch.nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)  # Adjust for CIFAR-10 channels
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create PyTorch model and optimizer
model = MNISTModel().to(device)
optimizer = torch.optim.Adam(model.parameters())

# Loss function (CrossEntropyLoss for multi-class classification)
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop: i counts batches within the current epoch
for epoch in range(5):
  for i, (data, target) in enumerate(train_loader):
    
    data, target = data.to(device), target.to(device)
    
    optimizer.zero_grad()
    output = model(data)
    #print(output.shape)
    #print(target.shape)
    
    loss = loss_fn(output, target)
    loss.backward()
    #optimizer.step()
    xm.optimizer_step(optimizer)  # Use XLA optimizer step
    if (i+1) % 100 == 0:
        print(f"Epoch: {epoch+1} [{i+1}/{len(train_loader)} ({100. * (i+1) / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

print('Finished Training')

Error:

ValueError                                Traceback (most recent call last)
Cell In[7], line 65
     63 loss.backward()
     64 #optimizer.step()
---> 65 xm.optimizer_step(optimizer)  # Use XLA optimizer step
     66 if (i+1) % 100 == 0:
     67     print(f"Epoch: {epoch+1} [{i+1}/{len(train_loader)} ({100. * (i+1) / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

File /usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py:1021, in optimizer_step(optimizer, barrier, optimizer_args, groups, pin_layout)
    992 def optimizer_step(optimizer,
    993                    barrier=False,
    994                    optimizer_args={},
    995                    groups=None,
    996                    pin_layout=True):
    997   """Run the provided optimizer step and issue the XLA device step computation.
    998 
    999   Args:
   (...)
   1019     The same value returned by the `optimizer.step()` call.
   1020   """
-> 1021   reduce_gradients(optimizer, groups=groups, pin_layout=pin_layout)
   1022   loss = optimizer.step(**optimizer_args)
   1023   if barrier:

File /usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py:979, in reduce_gradients(optimizer, groups, pin_layout)
    965 def reduce_gradients(optimizer, groups=None, pin_layout=True):
    966   """Reduces all the gradients handled by an optimizer.
    967 
    968   Args:
   (...)
    977       See `xm.all_reduce` for details.
    978   """
--> 979   cctx = CollectiveContext()
    980   count = max(cctx.replica_devcount, cctx.world_size)
    981   if count > 1:

File /usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py:78, in CollectiveContext.__init__(self, groups)
     76     self.requires_intercore_reduce = self.replica_devcount > 1
     77     if self.requires_interhost_reduce:
---> 78       self.interhost_group, ranks = _make_interhost_group(
     79           self.replica_devcount, self.world_size)
     80       self.is_reduce_host = self.ordinal in ranks
     81 else:
     82   # Standard replication path.

File /usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py:113, in _make_interhost_group(replica_devcount, world_size)
    107 def _make_interhost_group(replica_devcount, world_size):
    108   # Every host in a sea-of-devices case handles replica_devcount devices.
    109   # The replica device index 0 of each host does the inter-host replication
    110   # using torch.distributed.
    111   # The XLA CPU is a special case where there is one process per XLA CPU device,
    112   # which is also a virtual host within a physical host.
--> 113   ranks = tuple(range(0, world_size, replica_devcount))
    114   return _get_torch_dist_group(ranks), ranks

ValueError: range() arg 3 must not be zero
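
For reference, these are the two variants I have been considering as a workaround, based on my reading of the docs (so the exact arguments may well be wrong). Variant 1 reuses the model / optimizer / loss_fn / train_loader definitions from the script above and only replaces xm.optimizer_step() with a plain optimizer.step() followed by xm.mark_step(); variant 2 is the multi-process pattern with xmp.spawn, where I assume xm.optimizer_step() would have actual replicas to reduce over. Is one of these the intended way to run on a single v4-8 host under PJRT?

# Variant 1: single process -- no cross-replica reduce, just flush the
# lazy XLA graph after each optimizer step.
for epoch in range(5):
  for i, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # materialize the pending XLA computation

# Variant 2: one process per XLA device via xmp.spawn, keeping
# xm.optimizer_step(optimizer) inside the loop.
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    # build model, optimizer, loss_fn and a per-process DataLoader here,
    # then run the same loop as above but with
    # xm.optimizer_step(optimizer) instead of optimizer.step()

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())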