ValueError: need at least one array to concatenate (TPU error)

Hi,
I am trying to use PyTorch/XLA for an image classification problem, but when I run the code I get this error after the first epoch: `ValueError: need at least one array to concatenate`.
On GPU the code runs fine; the problem only occurs on TPU (using a single core).
Here is my validation function, which is throwing the error:

    # Body of valid_one_epoch(epoch, model, loss_fn, val_loader, device,
    # scheduler, schd_loss_update): run one pass over val_loader, accumulate
    # the sample-weighted loss and argmax predictions, print multi-class
    # accuracy, and optionally step the LR scheduler.
    model.to(device)
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        image_preds = model(imgs)
        # Keep per-batch predicted class indices and targets on the host so
        # they can be concatenated into one array after the loop.
        image_preds_all.append(torch.argmax(image_preds, 1).detach().cpu().numpy())
        image_targets_all.append(image_labels.detach().cpu().numpy())

        loss = loss_fn(image_preds, image_labels)

        # Weight by batch size so the running mean is per-sample, not per-batch.
        loss_sum += loss.item() * image_labels.shape[0]
        sample_num += image_labels.shape[0]

        if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(val_loader)):
            description = f'epoch {epoch} loss: {loss_sum/sample_num:.4f}'
            pbar.set_description(description)

    # If the loop never ran, np.concatenate([]) fails with an opaque
    # "need at least one array to concatenate", and sample_num == 0 would make
    # the scheduler step divide by zero. Fail early with a clear message.
    # NOTE(review): the traceback passes `val_parallel` and the failure only
    # appears after the first epoch. That is consistent with a torch_xla
    # per-device loader being built once and reused after it is exhausted.
    # If so, create it inside the epoch loop (or use MpDeviceLoader).
    if not image_preds_all:
        raise ValueError(
            f'epoch {epoch}: val_loader yielded no batches; if it is a '
            'torch_xla per-device loader, create a fresh one for every epoch '
            'instead of reusing an exhausted iterator'
        )

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    print('validation multi-class accuracy = {:.4f}'.format((image_preds_all==image_targets_all).mean()))

    if scheduler is not None:
        if schd_loss_update:
            # Loss-driven schedulers (e.g. ReduceLROnPlateau) take the metric.
            scheduler.step(loss_sum/sample_num)
        else:
            scheduler.step()

And here is the complete stack trace:

ValueError                                Traceback (most recent call last)
<ipython-input-18-bd145273442c> in <module>
     43                 xm.save(model.state_dict(),'{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch))
     44 
---> 45         xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=CFG["n_procs"], start_method='fork')
     46 
     47 

/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py in spawn(fn, args, nprocs, join, daemon, start_method)
    384   pf_cfg = _pre_fork_setup(nprocs)
    385   if pf_cfg.num_devices == 1:
--> 386     _start_fn(0, pf_cfg, fn, args)
    387   else:
    388     return torch.multiprocessing.start_processes(

/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py in _start_fn(index, pf_cfg, fn, args)
    321   # environment must be fully setup before doing so.
    322   _setup_replication()
--> 323   fn(gindex, *args)
    324 
    325 

<ipython-input-18-bd145273442c> in _mp_fn(rank, flags)
     38 
     39                 with torch.no_grad():
---> 40                     valid_one_epoch(epoch, model, loss_fn, val_parallel, device, scheduler=None, schd_loss_update=False)
     41 
     42                 torch.cuda.empty_cache()

<ipython-input-16-83200ba39886> in valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler, schd_loss_update)
    105             pbar.set_description(description)
    106 
--> 107     image_preds_all = np.concatenate(image_preds_all)
    108     image_targets_all = np.concatenate(image_targets_all)
    109     print('validation multi-class accuracy = {:.4f}'.format((image_preds_all==image_targets_all).mean()))

<__array_function__ internals> in concatenate(*args, **kwargs)

ValueError: need at least one array to concatenate