RandCropByPosNegLabeld dimension mismatch: "The size of tensor a (118210248) must match the size of tensor b (101117952) at non-singleton dimension 0" (MONAI RandCropByPosNegLabeld error)

Hi, I am training a UNet model on 3D data — NIfTI volumes with variable sizes in the z dimension, so the number of slices differs between volumes. Each volume has a corresponding label, but I am getting this error during the first epoch. Please let me know how to resolve it. I am sharing my code and the error below.

CODE
from monai.utils import first, set_determinism
import nibabel
from monai.transforms import (
AsDiscrete,
Activationsd,
EnsureType,
DivisiblePadd,
AsDiscreted,
EnsureChannelFirstd,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
NormalizeIntensityd,
Resized,
RandCropByPosNegLabeld,
SaveImaged,
ScaleIntensityRanged,
Spacingd,
Invertd,
)
from monai.handlers.utils import from_engine

from monai.handlers import TensorBoardStatsHandler
from monai.losses import DiceLoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.networks.nets import UNet
from monai.data import DataLoader
from monai.config import print_config

from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import tempfile
import shutil
import os
import glob

# Root directories for the input data and for saved model checkpoints.
data_dir = "/home/hpc/iwi5/iwi5131h"
root_dir = "/home/hpc/iwi5/iwi5131h"

# Collect image/label NIfTI volumes. BUG FIX: the glob patterns had lost the
# leading "*" (markdown mangling), so nothing (or the wrong files) matched.
# Sorting both lists keeps image/label pairs aligned, provided filenames
# correspond one-to-one across the two folders — verify this, because a
# mis-paired image/label (different spatial shapes) is exactly what makes
# RandCropByPosNegLabeld raise the reported tensor-size-mismatch error.
train_images = sorted(glob.glob(os.path.join(data_dir, "traindata", "*.nii.gz")))
print(train_images)
train_labels = sorted(glob.glob(os.path.join(data_dir, "trainlabels", "*.nii.gz")))

val_images = sorted(glob.glob(os.path.join(data_dir, "valdata", "*.nii.gz")))
val_labels = sorted(glob.glob(os.path.join(data_dir, "vallabels", "*.nii.gz")))

# MONAI dictionary transforms consume a list of {"image": path, "label": path} dicts.
train_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]
val_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(val_images, val_labels)]

# Training pipeline: load -> channel-first -> intensity window -> crop to
# foreground -> RAS orientation -> resample -> random pos/neg patch sampling.
# NOTE(review): RandCropByPosNegLabeld flattens image and label together in
# map_binary_to_indices, so each image and its label must have identical
# spatial shapes after the preceding transforms. The reported size-mismatch
# error means at least one image/label pair differs in shape — check the raw
# volume dimensions and the image/label pairing before training.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Window intensities to [-57, 164] and rescale to [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 2.0), mode=("bilinear", "nearest")),
        # Sample 4 random 96^3 patches per volume, balanced 1:1 between
        # patches centred on foreground (pos) and background (neg) voxels.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)
# Validation pipeline: identical deterministic preprocessing to training,
# but no random cropping — whole volumes go through sliding-window inference.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Same intensity window as training so the model sees the same range.
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 2.0), mode=("bilinear", "nearest")),
    ]
)
# Cache all pre-processed volumes in memory (cache_rate=1.0) so the
# deterministic transforms run only once per sample; num_workers=0 keeps
# loading single-threaded.
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,
    num_workers=0,
)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)

val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=0,
)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)

device = torch.device("cuda:0")
# 3D UNet with 13 output channels.
# NOTE(review): 13 channels implies 12 foreground classes + background, but
# the original comment said "11 organs + background" — confirm the actual
# number of label classes in the dataset.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=13,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),  # len(channels) - 1 stride entries
    num_res_units=0,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

dice_metric = DiceMetric(include_background=False, reduction="mean")
dice_metric_batch = DiceMetric(include_background=False, reduction="mean_batch")

# 95th-percentile Hausdorff distance.
# NOTE(review): surf_metric_batch uses include_background=True while every
# other metric excludes background — confirm this asymmetry is intentional.
surf_metric = HausdorffDistanceMetric(include_background=False, distance_metric="euclidean", reduction="mean",
                                      percentile=95)
surf_metric_batch = HausdorffDistanceMetric(include_background=True, distance_metric="euclidean",
                                            reduction="mean_batch", percentile=95)

max_epochs = 1000
val_interval = 1  # validate every epoch
best_metric = -1
best_metrics = -1
best_metric_epoch = -1
best_metric_epochs = -1
epoch_loss_values = []  # per-epoch mean training loss
metric_values = []      # per-validation mean dice
smetric_values = []     # per-validation mean Hausdorff distance
# Post-processing: argmax over channels, then one-hot to 13 channels so the
# metrics compare per-class binary masks.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=13)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=13)])

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        # Each batch holds the random 96^3 patches sampled from one volume.
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                # Whole volumes are validated with a sliding window of the
                # same patch size used for training.
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # Accumulate metrics over the validation set.
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)
                surf_metric(y_pred=val_outputs, y=val_labels)
                surf_metric_batch(y_pred=val_outputs, y=val_labels)

            # Aggregate the final mean dice / Hausdorff results.
            metric = dice_metric.aggregate().item()
            metric_batch_org = dice_metric_batch.aggregate()
            smetric = surf_metric.aggregate().item()
            smetric_batch_org = surf_metric_batch.aggregate()

            # Per-class dice scores. Iterating (instead of 13 hard-coded
            # index prints) also avoids an IndexError: with
            # include_background=False the batch metric has only 12 entries.
            print("Dice per class")
            for per_class in metric_batch_org:
                print(per_class.item())

            print("Hausdorff Distances")

            # BUG FIX: the original printed metric_batch_org[12] (a dice
            # value) as the last Hausdorff entry; iterate smetric_batch_org
            # so only Hausdorff values are printed here.
            for per_class in smetric_batch_org:
                print(per_class.item())

            # Reset accumulator state for the next validation round.
            dice_metric.reset()
            dice_metric_batch.reset()
            surf_metric.reset()
            surf_metric_batch.reset()

            metric_values.append(metric)
            smetric_values.append(smetric)

            # Checkpoint whenever mean dice improves.
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(
                    root_dir, "resunet_trained_010102.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}")

import numpy as np

# Persist the training curves for offline plotting/analysis.
np.save("loss_resunet_trained_010102.npy", epoch_loss_values)
np.save("DICE_resunet_trained_010102.npy", metric_values)
np.save("Hausdorff_resunet_trained_010102.npy", smetric_values)

ERROR
epoch 1/1000
Traceback (most recent call last):
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/transform.py”, line 102, in apply_transform
return apply_transform(transform, data, unpack_items)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/transform.py”, line 66, in apply_transform
return transform(parameters)
^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/croppad/dictionary.py”, line 861, in call
self.randomize(label, fg_indices, bg_indices, image)
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/croppad/dictionary.py”, line 852, in randomize
self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image)
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/croppad/array.py”, line 1060, in randomize
fg_indices
, bg_indices
= map_binary_to_indices(label, image, self.image_threshold)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/utils.py”, line 320, in map_binary_to_indices
bg_indices = nonzero(img_flat & ~label_flat)
~^~~~~
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/data/meta_tensor.py”, line 268, in torch_function
ret = super().torch_function(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/_tensor.py”, line 1295, in torch_function
ret = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (142909200) must match the size of tensor b (92475392) at non-singleton dimension 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File “/home/hpc/iwi5/iwi5131h/code.py”, line 165, in
for batch_data in train_loader:
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/utils/data/dataloader.py”, line 634, in next
data = self._next_data()
^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/utils/data/dataloader.py”, line 678, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py”, line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py”, line 51, in
data = [self.dataset[idx] for idx in possibly_batched_index]
~~~~~~~~~~~~^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/data/dataset.py”, line 107, in getitem
return self._transform(index)
^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/data/dataset.py”, line 921, in _transform
data = apply_transform(_transform, data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/transform.py”, line 129, in apply_transform
raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.croppad.dictionary.RandCropByPosNegLabeld object at 0x7f0fa5d4bb10>