Input tensor is not an XLA tensor: torch.FloatTensor

I tried implementing an autoencoder model using a custom dataset on tpu via PyTorch xla.
It was showing me error as Input tensor is not an XLA tensor: torch.FloatTensor so i changed all the variables to tpu device and I still can’t figure out why is it showing so.

Below is a snippet of error and it seems that the same snippet is repeating for different tpu’s

Traceback (most recent call last):
File “/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py”, line 119, in _start_fn
fn(gindex, *args)
File “”, line 94, in map_fn
train_losses,val_losses = train(100,train_loader,test_loader,criterion,device)
File “”, line 21, in train
outputs = model(img_grey)
File “/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py”, line 550, in call
result = self.forward(*input, **kwargs)
File “”, line 41, in forward
midlevel_features = self.midlevel_resnet(input)

this also
Exception in device=TPU:5: torch_xla/csrc/aten_xla_bridge.cpp:69 : Check failed: xtensor

Train Function

def train(n_epochs,data_loader_train,data_loader_val,criterion,device):
  best_losses=1e10 
  model.train()  
  train_losses=[]
  val_losses=[]

  ## Trains
  train_start = time.time()
  for epoch in range(1, flags['num_epochs']+1):
 
      data_loss = 0.0
      total =0 

      time_start=time.time()
      para_train_loader = pl.ParallelLoader(data_loader_train, [device]).per_device_loader(device)
      for i,data in enumerate(para_train_loader):

          img_lab,img_original,img_grey,img_ab,target = data
          img_lab,img_original,img_grey,img_ab,target= img_lab.to(device),img_original.to(device),img_grey.to(device),img_ab.to(device),target.to(device)
        
          outputs = model(img_grey)
          loss = criterion(img_ab,outputs)

          optimizer.zero_grad()
          loss.backward()
          # optimizer.step()
          xm.optimizer_step(optimizer)

          data_loss += loss.item()
          total += 1         
         
          if i % 50 == 0:
            print('Epoch: {} \tIteration: {} Training Loss: {:.6f} \tTime Taken :{:.3f}'.format(
          epoch, 
          i,
          data_loss/total,
          time.time()-time_start
          ))
 

      data_loss = data_loss/total
      print('Process: {} \tEpoch: {} \tTraining Loss: {:.6f} \tTime Taken :{:.3f}'.format(
          index,
          epoch, 
          data_loss,
          time.time() - time_start
          ))
      train_losses.append(data_loss) 
      losses = validation(data_loader_val, model, criterion,epoch,device)
      val_losses.append(losses)
          # Save checkpoint and replace old best model if current model is better
      if losses < best_losses:
        best_losses = losses
        print('=====Saving Best Model========')
        torch.save(model.state_dict(), 'checkpoints/tpu_{}_resent_model-epoch-{}-losses-{:.3f}.pth'.format(index,epoch,losses)) 
  print("Process", index, "finished training. Train time was:", time.time() - time_start)    
  return train_losses,val_losses

map function

def map_fn(index, flags):
  # Sets a common random seed - both for initialization and ensuring graph is the same
  torch.manual_seed(flags['seed'])

  # Acquires the (unique) Cloud TPU core corresponding to this process's index
  device = xm.xla_device()  


  image_size=224
  transform=transforms.Compose([
          transforms.Resize((image_size,image_size)),
          transforms.RandomHorizontalFlip(),
          transforms.ToTensor()
          # transforms.Normalize([0.485, 0.456, 0.406],
          #                      [0.229, 0.224, 0.225]
                              #  )  # Imagenet standards
      ])

  train_path=str('/content/gdrive/My Drive/Dataset_Grey_RGB/images_1/train')
  test_path=str('/content/gdrive/My Drive/Dataset_Grey_RGB/images_1/val')


  # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
  #                                 std=[0.229, 0.224, 0.225])
  # to_rgb = transforms.Lambda(lambda image: image.convert('RGB'))
  # resize = transforms.Resize((224, 224))
  # my_transform = transforms.Compose([resize, to_rgb, transforms.ToTensor(), normalize])

  # # Downloads train and test datasets
  # # Note: master goes first and downloads the dataset only once (xm.rendezvous)
  # #   all the other workers wait for the master to be done downloading.

  # if not xm.is_master_ordinal():
  #   xm.rendezvous('download_only_once')

  train_dataset = GrayscaleImageFolder(root=train_path,transform=transform,)

  test_dataset = GrayscaleImageFolder(root=test_path,transform=transform)
  
  if xm.is_master_ordinal():
    xm.rendezvous('init')
    # xm.rendezvous('download_only_once')
  
  # Creates the (distributed) train sampler, which let this process only access
  # its portion of the training dataset.
  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)
  
  test_sampler = torch.utils.data.distributed.DistributedSampler(
    test_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=False)
  
  # Creates dataloaders, which load data in batches
  # Note: test loader is not shuffled or sampled
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=flags['batch_size'],
      sampler=train_sampler,
      pin_memory=True,
      num_workers=flags['num_workers'],
      drop_last=True)

  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=flags['batch_size'],
      sampler=test_sampler,
      pin_memory=True,
      num_workers=flags['num_workers'],
      drop_last=True)
  
  ## Network, optimizer, and loss function creation|

  # Creates Autoencoer for 10 classes
  # Note: each process has its own identical copy of the model
  #  Even though each model is created independently, they're also
  #  created in the same way.
  # net = torchvision.models.alexnet(num_classes=10).to(device).train()
  net = AutoEncoder().to(device)
  criterion = torch.nn.MSELoss()
  optimizer = torch.optim.RMSprop(net.parameters(), lr=0.001, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False)


  train_losses,val_losses   = train(100,train_loader,test_loader,criterion,device)
  plt.plot(train_losses)
  plt.plot(val_losses)
  plt.show()

I am attaching code as well. Please let me know how can I rectify it or point me to the correct location where I can find its solution.

Code