Hi!
I am fairly new to pytorch_xla. While getting familiar with it, I have been testing basic code that trains a model on CIFAR-10 on my Cloud TPU v4-8 with the tpu-vm-v4-pt-1.13 runtime. I keep getting an error at the optimizer step. I searched online for a solution but could not find one.
torch: 1.13.0+cu117
torch_xla: 1.13
device: "xla:0"
Please take a look at the code and the error below.
Code:
# Import libraries
import os
os.environ['PJRT_DEVICE'] = 'TPU'

import torch
from torch.utils.data import DataLoader
import torch_xla.core.xla_model as xm
from torchvision import datasets, transforms

device = xm.xla_device()

# Define data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load CIFAR-10 dataset
train_data = datasets.CIFAR10('~/.torch/datasets', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

# Define model class (same as my previous MNIST example, hence the name)
class MNISTModel(torch.nn.Module):
    def __init__(self):
        super(MNISTModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 6, 5)  # 3 input channels for CIFAR-10
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create PyTorch model and optimizer
model = MNISTModel().to(device)
optimizer = torch.optim.Adam(model.parameters())

# Loss function (CrossEntropyLoss for multi-class classification)
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop
for epoch in range(5):
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        # optimizer.step()
        xm.optimizer_step(optimizer)  # Use XLA optimizer step
        if (i + 1) % 100 == 0:
            print(f"Epoch: {epoch+1} [{i+1}/{len(train_loader)} "
                  f"({100. * (i+1) / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

print('Finished Training')
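My understanding of xm.optimizer_step() (and the reason I use it instead of the commented-out optimizer.step()) is that it all-reduces the gradients across replicas and then runs the optimizer step plus the XLA device step. On a single process I would therefore expect the plain variant below to behave the same, though that equivalence is my own assumption and I have not verified it:

# Plain single-device variant (my assumption: with one process there is
# nothing to all-reduce, so a normal step plus xm.mark_step() should match
# what xm.optimizer_step() does). Reuses model/optimizer/loader from above.
for data, target in train_loader:
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()  # materialize the lazily-recorded XLA graph for this step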
Error:
ValueError Traceback (most recent call last)
Cell In[7], line 65
63 loss.backward()
64 #optimizer.step()
---> 65 xm.optimizer_step(optimizer) # Use XLA optimizer step
66 if (i+1) % 100 == 0:
67 print(f"Epoch: {epoch+1} [{i+1}/{len(train_loader)} ({100. * (i+1) / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")
File /usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py:1021, in optimizer_step(optimizer, barrier, optimizer_args, groups, pin_layout)
992 def optimizer_step(optimizer,
993 barrier=False,
994 optimizer_args={},
995 groups=None,
996 pin_layout=True):
997 """Run the provided optimizer step and issue the XLA device step computation.
998
999 Args:
(...)
1019 The same value returned by the `optimizer.step()` call.
1020 """
-> 1021 reduce_gradients(optimizer, groups=groups, pin_layout=pin_layout)
1022 loss = optimizer.step(**optimizer_args)
1023 if barrier:
File /usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py:979, in reduce_gradients(optimizer, groups, pin_layout)
965 def reduce_gradients(optimizer, groups=None, pin_layout=True):
966 """Reduces all the gradients handled by an optimizer.
967
968 Args:
(...)
977 See `xm.all_reduce` for details.
978 """
--> 979 cctx = CollectiveContext()
980 count = max(cctx.replica_devcount, cctx.world_size)
981 if count > 1:
File /usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py:78, in CollectiveContext.__init__(self, groups)
76 self.requires_intercore_reduce = self.replica_devcount > 1
77 if self.requires_interhost_reduce:
---> 78 self.interhost_group, ranks = _make_interhost_group(
79 self.replica_devcount, self.world_size)
80 self.is_reduce_host = self.ordinal in ranks
81 else:
82 # Standard replication path.
File /usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py:113, in _make_interhost_group(replica_devcount, world_size)
107 def _make_interhost_group(replica_devcount, world_size):
108 # Every host in a sea-of-devices case handles replica_devcount devices.
109 # The replica device index 0 of each host does the inter-host replication
110 # using torch.distributed.
111 # The XLA CPU is a special case where there is one process per XLA CPU device,
112 # which is also a virtual host within a physical host.
--> 113 ranks = tuple(range(0, world_size, replica_devcount))
114 return _get_torch_dist_group(ranks), ranks
ValueError: range() arg 3 must not be zero
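Reading the traceback, range(0, world_size, replica_devcount) raises because replica_devcount is 0, so CollectiveContext apparently decides an inter-host reduce is needed while seeing zero replica devices. My guess (unverified) is that under PJRT the script is meant to be launched through xla_multiprocessing rather than run as a single flat script. Below is a minimal sketch of that launch, where _mp_fn is a hypothetical wrapper I would put around the training loop above:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # index: the process index assigned by xmp.spawn
    device = xm.xla_device()
    # ... build the transforms, loader, model, and optimizer here, then run
    # the same training loop with xm.optimizer_step(optimizer), which would
    # now have real replicas to reduce over (my assumption)
    print(f"process {index}: device={device}, world size={xm.xrt_world_size()}")

if __name__ == '__main__':
    xmp.spawn(_mp_fn)  # one process per local TPU device

Is that the right way to run this on a v4-8, or is single-process xm.optimizer_step() supposed to work with PJRT_DEVICE=TPU? Any pointers would be appreciated.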