Hi, I am training a UNet model on 3D data — NIfTI volumes with a variable size in the z dimension (the number of slices differs per volume). Each volume has a corresponding label, but I am getting the error below on the first epoch. Please let me know how to resolve it. I am sharing my code and the error…
CODE
from monai.utils import first, set_determinism
import nibabel
from monai.transforms import (
AsDiscrete,
Activationsd,
EnsureType,
DivisiblePadd,
AsDiscreted,
EnsureChannelFirstd,
Compose,
CropForegroundd,
LoadImaged,
Orientationd,
NormalizeIntensityd,
Resized,
RandCropByPosNegLabeld,
SaveImaged,
ScaleIntensityRanged,
Spacingd,
Invertd,
)
from monai.handlers.utils import from_engine
from monai.handlers import TensorBoardStatsHandler
from monai.losses import DiceLoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.networks.nets import UNet
from monai.data import DataLoader
from monai.config import print_config
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import tempfile
import shutil
import os
import glob
# Root directories for the dataset and for saved outputs (checkpoints, metrics).
data_dir = "/home/hpc/iwi5/iwi5131h"
root_dir = "/home/hpc/iwi5/iwi5131h"

# NOTE(review): the pasted patterns were ".nii.gz" with no wildcard, which
# matches nothing; "*.nii.gz" matches every compressed NIfTI file in the folder.
train_images = sorted(glob.glob(os.path.join(data_dir, "traindata", "*.nii.gz")))
print(train_images)
train_labels = sorted(glob.glob(os.path.join(data_dir, "trainlabels", "*.nii.gz")))
val_images = sorted(glob.glob(os.path.join(data_dir, "valdata", "*.nii.gz")))
val_labels = sorted(glob.glob(os.path.join(data_dir, "vallabels", "*.nii.gz")))

# Pair each image with its label. zip() relies on the two sorted listings
# lining up one-to-one — verify the folders contain matching file names.
train_files = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(train_images, train_labels)
]
val_files = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(val_images, val_labels)
]
# Training pipeline: load -> channel-first -> window intensities -> crop to
# foreground -> reorient -> resample -> random 96^3 pos/neg patch sampling.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Window HU values to [-57, 164] and rescale to [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 2.0), mode=("bilinear", "nearest")),
        # NOTE(review): the reported RuntimeError ("size of tensor a (142909200)
        # must match the size of tensor b (92475392)") comes from inside this
        # transform: it flattens image and label and combines them elementwise,
        # so it requires both to have identical spatial shapes. The mismatch
        # means at least one label volume has a different shape than its image.
        # Check each pair's shapes/affines with nibabel, and resample labels
        # onto the image grid (e.g. monai.transforms.ResampleToMatchd with
        # key_dst="image", mode="nearest") before this point — TODO confirm
        # which pair is mismatched.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        # Other random augmentations (e.g. RandAffined) can be appended here.
    ]
)

# Validation pipeline: identical to training but deterministic — no random
# cropping, so sliding-window inference sees whole resampled volumes.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 2.0), mode=("bilinear", "nearest")),
    ]
)
# Cache fully-transformed samples in RAM (cache_rate=1.0) to avoid re-running
# the deterministic transforms every epoch.
train_ds = CacheDataset(
    data=train_files, transform=train_transforms,
    cache_rate=1.0, num_workers=0)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=0)
val_ds = CacheDataset(
    data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)

device = torch.device("cuda:0")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=13,  # 12 foreground classes + background (original comment said 11 organs — TODO confirm)
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),  # one stride per down-sampling step: len(channels) - 1
    num_res_units=0,
    norm=Norm.BATCH,
).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

# Metrics: exclude background consistently in ALL of them so the per-class
# ("mean_batch") aggregates of Dice and Hausdorff have the same length.
# (The original mixed include_background=True/False between the two batch
# metrics, which made the per-class printouts disagree in length.)
dice_metric = DiceMetric(include_background=False, reduction="mean")
dice_metric_batch = DiceMetric(include_background=False, reduction="mean_batch")
surf_metric = HausdorffDistanceMetric(include_background=False, distance_metric="euclidean",
                                      reduction="mean", percentile=95)
surf_metric_batch = HausdorffDistanceMetric(include_background=False, distance_metric="euclidean",
                                            reduction="mean_batch", percentile=95)

max_epochs = 1000
val_interval = 1          # validate every epoch
best_metric = -1          # best mean Dice seen so far
best_metric_epoch = -1    # epoch at which best_metric occurred
epoch_loss_values = []    # per-epoch average training loss
metric_values = []        # per-validation mean Dice
smetric_values = []       # per-validation mean Hausdorff distance
# Post-processing: argmax the logits then one-hot both prediction and label
# to 13 channels so the metrics compare like with like.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=13)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=13)])
# Main optimization loop: one training pass per epoch, with validation
# (sliding-window inference + Dice/Hausdorff metrics) every val_interval epochs.
for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(
            f"{step}/{len(train_ds) // train_loader.batch_size}, "
            f"train_loss: {loss.item():.4f}")
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                # Whole volumes don't fit the network input; tile with 96^3
                # windows (4 windows per forward pass) and stitch the output.
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(
                    val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                # Accumulate metric state for this validation round.
                dice_metric(y_pred=val_outputs, y=val_labels)
                dice_metric_batch(y_pred=val_outputs, y=val_labels)
                surf_metric(y_pred=val_outputs, y=val_labels)
                surf_metric_batch(y_pred=val_outputs, y=val_labels)

            # Aggregate over all validation cases.
            metric = dice_metric.aggregate().item()
            metric_batch_org = dice_metric_batch.aggregate()
            smetric = surf_metric.aggregate().item()
            smetric_batch_org = surf_metric_batch.aggregate()

            # Iterate over however many per-class entries the aggregate holds.
            # The original hard-coded indices 0..12, which raises IndexError
            # when include_background=False yields only 12 foreground classes,
            # and its Hausdorff section mistakenly printed metric_batch_org[12]
            # (a Dice value) as the last entry.
            print("Per-class Dice")
            for class_value in metric_batch_org:
                print(class_value.item())
            print("Hausdorff Distances")
            for class_value in smetric_batch_org:
                print(class_value.item())

            # Reset metric state so the next round starts fresh.
            dice_metric.reset()
            dice_metric_batch.reset()
            surf_metric.reset()
            surf_metric_batch.reset()

            metric_values.append(metric)
            smetric_values.append(smetric)
            # Checkpoint whenever mean Dice improves.
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), os.path.join(
                    root_dir, "resunet_trained_010102.pth"))
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )

print(
    f"train completed, best_metric: {best_metric:.4f} "
    f"at epoch: {best_metric_epoch}")

import numpy as np
# Persist the learning curves for later plotting/analysis.
np.save("loss_resunet_trained_010102.npy", epoch_loss_values)
np.save("DICE_resunet_trained_010102.npy", metric_values)
np.save("Hausdorff_resunet_trained_010102.npy", smetric_values)
ERROR
epoch 1/1000
Traceback (most recent call last):
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/transform.py”, line 102, in apply_transform
return apply_transform(transform, data, unpack_items)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/transform.py”, line 66, in apply_transform
return transform(parameters)
^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/croppad/dictionary.py”, line 861, in call
self.randomize(label, fg_indices, bg_indices, image)
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/croppad/dictionary.py”, line 852, in randomize
self.cropper.randomize(label=label, fg_indices=fg_indices, bg_indices=bg_indices, image=image)
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/croppad/array.py”, line 1060, in randomize
fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/utils.py”, line 320, in map_binary_to_indices
bg_indices = nonzero(img_flat & ~label_flat)
~~~~~^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/data/meta_tensor.py”, line 268, in torch_function
ret = super().torch_function(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/_tensor.py”, line 1295, in torch_function
ret = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (142909200) must match the size of tensor b (92475392) at non-singleton dimension 0
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File “/home/hpc/iwi5/iwi5131h/code.py”, line 165, in
for batch_data in train_loader:
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/utils/data/dataloader.py”, line 634, in next
data = self._next_data()
^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/utils/data/dataloader.py”, line 678, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py”, line 51, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py”, line 51, in
data = [self.dataset[idx] for idx in possibly_batched_index]
~~~~~~~~~~~~^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/data/dataset.py”, line 107, in getitem
return self._transform(index)
^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/data/dataset.py”, line 921, in _transform
data = apply_transform(_transform, data)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File “/home/hpc/iwi5/iwi5131h/.conda/envs/cuda_torch/lib/python3.11/site-packages/monai/transforms/transform.py”, line 129, in apply_transform
raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.croppad.dictionary.RandCropByPosNegLabeld object at 0x7f0fa5d4bb10>