Not sure how to resolve this error, trying to use a TPU on Google Colab to train a neural network. Looking for help please!

Here is the code, I have highlighted in bold where the error is occurring:

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm #for progress bar during training
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import pickle
import os
import time
import torch_xla
import torch_xla.core.xla_model as xm

#from model import UNET
#from utils import (
#load_checkpoint,
#save_checkpoint,
#get_loaders,
#check_accuracy,
#save_predictions_as_imgs,
#) WE ARE NOT IMPORTING THESE AS USING GOOGLE COLAB SO ALL IN SAME SCRIPT

Hyperparameters etc.

slice1 = 20
slice2 = 28
slice = 22
DATE_TIME = datetime.now()
LEARNING_RATE = 0.0001
#DEVICE = “cuda” if torch.cuda.is_available() else “cpu” # device agnostic code
DEVICE = xm.xla_device()
BATCH_SIZE = 10
NUM_EPOCHS = 1
NUM_WORKERS = 5
IMAGE_HEIGHT = 376 # images originally 576, masks originally 640
IMAGE_WIDTH = 376 # images originally 576, masks originally 640
PIN_MEMORY = True
LOAD_MODEL = False
MAIN_DIR = ‘/content/drive/MyDrive/Oxford/Year_4/4YP/Data/2D_data/2D_data_sliced_22’
TRAIN_IMG_DIR = os.path.join(MAIN_DIR, ‘train/images’)
TRAIN_MASK_DIR = os.path.join(MAIN_DIR, ‘train/masks’)
VAL_IMG_DIR = os.path.join(MAIN_DIR, ‘val/images’)
VAL_MASK_DIR = os.path.join(MAIN_DIR, ‘val/masks’)
VAL_PREDS_DIR = os.path.join(MAIN_DIR, ‘predictions’)

LOAD_MODEL_DIR = os.path.join(MAIN_DIR, ‘trained_model(B: 10, LR: 1e-05, E: 20)’)

TRAINING_RESULTS_DIR = os.path.join(MAIN_DIR, f’training_loss/training_results(B: {BATCH_SIZE}, LR: {LEARNING_RATE}, E: {NUM_EPOCHS}).pkl’)

SAVE_MODEL_DIR = os.path.join(MAIN_DIR, f’trained_model(B: {BATCH_SIZE}, LR: {LEARNING_RATE}, E: {NUM_EPOCHS})')

Training Function

def train_fn(loader, model, optimizer, loss_fn, scaler):
loop = tqdm(loader)

# Setup train loss variable
train_loss = 0

for batch_idx, (data, targets) in enumerate(loop):
    #print(f'Data Shape before: {data.shape}')
    #print(f'Target Shape before: {targets.shape}')
    data = data.type(torch.float32).to(device=DEVICE)
    targets = targets.type(torch.float32).unsqueeze(1).to(device=DEVICE)
    ##data = data.to(device=DEVICE) #putting data + targets onto device we are using
    ##targets = targets.float().unsqueeze(1).to(device=DEVICE) # unsqueeze(1) to add a channel dimension
    #print(f'Data Shape after: {data.shape}')
    #print(f'Target Shape after: {targets.shape}')

    
    with torch.cuda.amp.autocast():
        # forward pass
        predictions = model(data)
        predictions = predictions.type(torch.float32)
        # calculate the loss
        print(f'Data Type of Targets: {targets.dtype}')
        print(f'Data Type of Preds: {predictions.dtype}')
        loss = loss_fn(predictions, targets)
        print(f'Data Type of Loss: {loss.dtype}')
        #print(f'Data Type of Loss.item(): {loss.item().dtype}')
        #loss = loss.type(torch.float32)
        #train_loss += loss.item()
        train_loss += loss

    # backward
    optimizer.zero_grad()
    ##scaler.scale(loss).backward()
    loss.backward()
    **xm.optimizer_step(optimizer, barrier=True)**
    ##scaler.step(optimizer)
    ##scaler.update()
    


    # update tqdm loop
    loop.set_postfix(loss=loss.item())

train_loss = train_loss/len(loader)

return train_loss

Validation Testing Function

def val_fn(loader, model, loss_fn):
loop = tqdm(loader)

# Setup val loss variable
val_loss = 0

for batch_idx, (data, targets) in enumerate(loop):
    data = data.to(device=DEVICE) #putting data + targets onto device we are using
    targets = targets.float().unsqueeze(1).to(device=DEVICE) # unsqueeze(1) to add a channel dimension

    
    with torch.cuda.amp.autocast():
        # forward pass
        predictions = model(data)
        # calculate the loss
        loss = loss_fn(predictions, targets)
        val_loss += loss.item()

    # update tqdm loop
    loop.set_postfix(loss=loss.item())

val_loss = val_loss/len(loader)

return val_loss

def main():
train_transform = A.Compose(
[
#A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
A.Rotate(limit=35, p=1.0),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.1),
A.Normalize(
mean=[0.0], #, 0.0, 0.0],
std=[1.0], #, 1.0, 1.0],
max_pixel_value=255.0,
),
ToTensorV2(),
],
)

val_transforms = A.Compose(
    [
        #A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0], # 0.0, 0.0],
            std=[1.0], # 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)

# Creating an instance of our UNET Model, calling a loss_fn and optimizer
model = UNET(in_channels=1, out_channels=1).to(DEVICE) # if we were doing mutliclass classification, change out_channels to number of classes and change loss_fn to CrossEntropyLoss
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_loader, val_loader = get_loaders(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    BATCH_SIZE,
    train_transform,
    val_transforms,
    NUM_WORKERS,
    PIN_MEMORY,
)

if LOAD_MODEL:
    load_checkpoint(torch.load(LOAD_MODEL_DIR), model)


# checking accuracy of loaded model - 80% explanation below
initial_dice_score = check_accuracy(val_loader, model, device=DEVICE)
print(f'Initial DCS: {initial_dice_score:.3f}')
scaler = torch.cuda.amp.GradScaler()



#Create empty results dictionary to store training and validation loss if we are not loading a new model
results = {'train_loss': [],
          'val_loss': [],
          'training_dice': [],
          'val_dice': []}

epoch_number = 1

start_time = time.time()

for epoch in range(NUM_EPOCHS):
    
    print(f'Epoch Number: {epoch_number}')

    train_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler)

    val_loss = val_fn(val_loader, model, loss_fn)

    # save model
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer":optimizer.state_dict(),
    }
    save_checkpoint(checkpoint, SAVE_MODEL_DIR)

    # check accuracy
    train_dice_score = check_accuracy(train_loader, model, device=DEVICE)
    val_dice_score = check_accuracy(val_loader, model, device=DEVICE)

    

    #store training loss
    results['train_loss'].append(train_loss)
    results['val_loss'].append(val_loss)
    results['training_dice'].append(train_dice_score.item())
    results['val_dice'].append(val_dice_score.item())

    #with open(TRAINING_LOSS_DIR, 'w') as f:
      #f.write('Training Loss:\n')
      #f.write(str(results['train_loss']))
      #f.write(' \n')
      #f.write('Validation Loss:\n')
      #f.write(str(results['val_loss']))
      #f.write('Accuracy:\n')
      #f.write(str(results['accuracy']))
      #f.write('Dice Score:\n')
      #f.write(str(results['dice_score']))
    
    print(f'Training Loss: {train_loss:.3f}, Validation Loss: {val_loss:.3f}')
    print(f'Training Dice: {train_dice_score.item():.3f}, Validation Dice: {val_dice_score.item():.3f}')

    
    epoch_number += 1

    

    # print some examples to a folder
#save_predictions_as_imgs(
 #       val_loader, model, batch_size=BATCH_SIZE, folder=VAL_PREDS_DIR, device=DEVICE
  #  )

elapsed_time = time.time() - start_time
print(f'Time taken to train: {elapsed_time:.1f} seconds')

with open(TRAINING_RESULTS_DIR, 'wb') as fp:
      pickle.dump(results, fp)
      print('dictionary saved successfully to file')
  
return results, model

if name == “main”:
results, model = main()

Here is the error message:

RuntimeError Traceback (most recent call last)
in <cell line: 254>()
253
254 if name == “main”:
→ 255 results, model = main()
256

3 frames
/usr/local/lib/python3.9/dist-packages/torch_xla/core/xla_model.py in mark_step(wait)
947 file=sys.stderr,
948 flush=True)
→ 949 torch_xla._XLAC._xla_step_marker(
950 torch_xla._XLAC._xla_get_default_device(), [],
951 wait=xu.getenv_as(‘XLA_SYNC_WAIT’, bool, wait))

RuntimeError: /pytorch/xla/torch_xla/csrc/xla_graph_executor.cpp:523 : Check failed: tensor_data
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::XLAGraphExecutor::CollectSyncTensors(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&)
torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >, absl::lts_20220623::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >
, absl::lts_20220623::Span<std::string const>, bool, bool, bool)
torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRefstd::string, bool)

_PyObject_MakeTpCall
_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

_PyEval_EvalCodeWithName
PyEval_EvalCode


_PyEval_EvalFrameDefault

_PyEval_EvalFrameDefault

_PyEval_EvalFrameDefault


_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault


PyObject_Call
_PyEval_EvalFrameDefault


_PyEval_EvalFrameDefault





_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault





_PyEval_EvalFrameDefault


_PyEval_EvalFrameDefault





_PyEval_EvalFrameDefault
_PyFunction_Vectorcall

PyObject_Call
_PyEval_EvalFrameDefault


_PyEval_EvalFrameDefault

_PyFunction_Vectorcall




_PyEval_EvalFrameDefault
_PyFunction_Vectorcall


_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall



PyObject_Call
_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault

_PyFunction_Vectorcall
_PyEval_EvalFrameDefault


_PyEval_EvalFrameDefault

_PyEval_EvalCodeWithName
PyEval_EvalCode


_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
PyObject_Call

Py_RunMain
Py_BytesMain

*** End stack trace ***