Unexpected bevaiour with torch xla

When using TPU device, my parameters are not updated while it functions correctly when run on a CPU device.
Folowing code replicates the problem.

if True:
    import torch_xla
    import torch_xla.core.xla_model as xm 
    os.environ["XRT_TPU_CONFIG"] = "tpu_worker;0;xxxx.xxxxxx:8470"
    dev = xm.xla_device()
else:
    dev = torch.device("cpu")

class Test(torch.nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.model = torch.nn.Sequential(torch.nn.Linear(5, 2))
        self.sms = [torch.nn.Linear(2, 2) for _ in range(2)]
        for smi, sm in enumerate(self.sms):
            self.register_parameter("sm_%d_wt" % smi, sm.weight)
            self.register_parameter("sm_%d_bias" % smi, sm.bias)
    
    def forward(self, x):
        reprs = self.model(x)
        logits = [sm(reprs) for sm in self.sms]
        return torch.stack(logits, dim=1)
    
    def loss(self, logits, labels):
        logits = torch.reshape(logits, [-1, 2])
        labels = torch.reshape(labels, [-1])
        print (logits.shape, labels.shape)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return torch.mean(loss)
    
    def to(self, dev):
        self.sms = [sm.to(dev) for sm in self.sms]
        self.model = self.model.to(dev)
        return self
    
test = Test()
test = test.to(dev)
print ([n for n, _ in list(test.named_parameters())])
b = torch.distributions.Binomial(1, torch.tensor([0.2, 0.8]))
x, y = torch.randn([10, 5], device=dev), b.sample((10, )).type(torch.int64).to(dev)
logits = test(x)
optimizer = torch.optim.SGD(test.parameters(), 1e-1)
loss = test.loss(logits, y)
print (test.sms[0].weight, test.sms[1].weight)
optimizer.zero_grad()
loss.backward()
if dev.type=='xla':
    xm.optimizer_step(optimizer, barrier=True)
else:
    optimizer.step()
print (test.sms[0].weight, test.sms[1].weight)

Output when using xla device. Note the step does not update parameters:

['sm_0_wt', 'sm_0_bias', 'sm_1_wt', 'sm_1_bias', 'model.0.weight', 'model.0.bias']
torch.Size([20, 2]) torch.Size([20])
Parameter containing:
tensor([[-0.6274,  0.3177],
        [-0.1122,  0.3985]], device='xla:1', requires_grad=True) Parameter containing:
tensor([[ 0.3345, -0.6810],
        [-0.5474, -0.4858]], device='xla:1', requires_grad=True)
Parameter containing:
tensor([[-0.6274,  0.3177],
        [-0.1122,  0.3985]], device='xla:1', requires_grad=True) Parameter containing:
tensor([[ 0.3345, -0.6810],
        [-0.5474, -0.4858]], device='xla:1', requires_grad=True)

Output when using cpu device. Note the step does update parameters

['sm_0_wt', 'sm_0_bias', 'sm_1_wt', 'sm_1_bias', 'model.0.weight', 'model.0.bias']
torch.Size([20, 2]) torch.Size([20])
Parameter containing:
tensor([[ 0.0524,  0.6992],
        [ 0.5904, -0.3944]], requires_grad=True) Parameter containing:
tensor([[ 0.4875,  0.0720],
        [ 0.4940, -0.7028]], requires_grad=True)
Parameter containing:
tensor([[ 0.0513,  0.6941],
        [ 0.5915, -0.3894]], requires_grad=True) Parameter containing:
tensor([[ 0.4982,  0.0738],
        [ 0.4833, -0.7046]], requires_grad=True)

What am I doing wrong?
torch_xla, torch version: 1.6.

I can fix the problem by replacing the initialization with the following definition.

 def __init__(self):
        super(Test, self).__init__()
        self.model = torch.nn.Sequential(torch.nn.Linear(5, 2)).to(dev)
        self.sms = torch.nn.ModuleList([torch.nn.Linear(2, 2) for _ in range(2)])
# redacted
# def to(self): ... 

Presumably, my earlier self.to is somehow detaching the parameter and interestingly revelas itself only on XLA. I would still very much like to understand what was I doing wrong earlier.