Transformer UNET will not train even with diverse data and augmentations

Hi everyone. I am building a UNETR (a Transformer-based UNet) for vertebrae instance segmentation on medical images with MONAI. I am using the VerSe dataset, a high-quality and diverse collection of CT scans with vertebra labels of the human spine.

Here is my code where I perform data augmentation on the images. I use ScaleIntensityRanged() to window the intensities so that only the brightest bone structures (like the spine) are kept:

train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.17, 1.17, 5.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=200,
            a_max=400,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[0],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[1],
            prob=0.10,
        ),
        RandFlipd(
            keys=["image", "label"],
            spatial_axis=[2],
            prob=0.10,
        ),
        RandRotate90d(
            keys=["image", "label"],
            prob=0.10,
            max_k=3,
        ),
        RandShiftIntensityd(
            keys=["image"],
            offsets=0.10,
            prob=0.50,
        ),
    ]
)
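For clarity, ScaleIntensityRanged with these arguments is (as far as I understand it) just a linear HU window: values are mapped from [200, 400] to [0, 1] and everything outside is clipped. Below is a rough NumPy equivalent of that mapping, plus a quick check of how much of the labelled spine actually lands inside the window (image_np / label_np are placeholders for one loaded CT volume and its label map, not variables from my pipeline):

import numpy as np

def window_hu(image_np, a_min=200.0, a_max=400.0, b_min=0.0, b_max=1.0):
    # Rough equivalent of ScaleIntensityRanged: linear rescale from [a_min, a_max]
    # to [b_min, b_max], clipping everything outside that range.
    scaled = (image_np - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    return np.clip(scaled, b_min, b_max)

# image_np / label_np are hypothetical arrays for a single case (HU values / integer class map).
# spine_hu = image_np[label_np > 0]
# print("spine voxels inside the 200-400 HU window:", np.mean((spine_hu >= 200) & (spine_hu <= 400)))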
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.17, 1.17, 5.0),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(keys=["image"], a_min=200, a_max=400, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
    ]
)

This is how I prepared the data. As you can see, the batch size is 2.

train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_num=24,
    cache_rate=1.0,
    num_workers=8,
)
batch_size = 2
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
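As a quick sketch of what the loader produces (tensor shapes, intensity range after the windowing, and which vertebra classes actually ended up in the sampled patches):

# Pull a single batch for inspection (sketch only).
check_batch = next(iter(train_loader))
img, lbl = check_batch["image"], check_batch["label"]
print(img.shape, lbl.shape)     # expected [B, 1, 96, 96, 96]; with num_samples=4 and batch_size=2,
                                # B may be 8 if MONAI's DataLoader/collate flattens the crops
print(img.min().item(), img.max().item())  # should lie in [0, 1] after ScaleIntensityRanged
print(torch.unique(lbl))        # vertebra class indices present in these patches (0 = background)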

Below is my model, loss function, optimizer, and training code:

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # set as an env var so it actually takes effect
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("medium")
no_classes = 29
model = UNETR(
    in_channels=1,
    out_channels=no_classes,
    img_size=(96, 96, 96),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    proj_type="perceptron",
    norm_name="batch",
    res_block=True,
    conv_block=True,
    dropout_rate=0,
).to(device)

lr = 1e-3
max_iterations = 6000
eval_num = 100
post_label = AsDiscrete(to_onehot=no_classes)
post_pred = AsDiscrete(argmax=True, to_onehot=no_classes)
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
loss_function = DiceLoss(to_onehot_y=True, softmax=True, include_background=False)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
scaler = torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,1000], gamma=0.1)
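For completeness, a single-batch overfitting check with the objects above would look roughly like the sketch below (fixed_batch is just a placeholder for one batch pulled from train_loader, this is not code from my actual run); if the Dice loss does not drop even there, the issue is in the model/loss wiring rather than the data:

# Sanity-check sketch: overfit a single fixed batch; the Dice loss should drop clearly
# within a few hundred iterations if the model, loss, and labels are wired up correctly.
fixed_batch = next(iter(train_loader))
x = fixed_batch["image"].to(device)
y = fixed_batch["label"].to(device)
model.train()
for i in range(200):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_function(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    if i % 20 == 0:
        print(i, loss.item())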

def validation(epoch_iterator_val):
    model.eval()
    with torch.no_grad():
        for batch in epoch_iterator_val:
            val_inputs, val_labels = (batch["image"].to(device), batch["label"].to(device))
            with torch.cuda.amp.autocast():
                val_outputs = sliding_window_inference(val_inputs, (96, 96, 96), 4, model)
            val_labels_list = decollate_batch(val_labels)
            val_labels_convert = [post_label(val_label_tensor) for val_label_tensor in val_labels_list]
            val_outputs_list = decollate_batch(val_outputs)
            val_output_convert = [post_pred(val_pred_tensor) for val_pred_tensor in val_outputs_list]
            dice_metric(y_pred=val_output_convert, y=val_labels_convert)
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0))  # noqa: B038
        mean_dice_val = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice_val

def train(global_step, train_loader, dice_val_best, global_step_best):
    model.train()
    epoch_loss = 0
    step = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator):
        step += 1
        x, y = (batch["image"].to(device), batch["label"].to(device))
        with torch.cuda.amp.autocast():
            logit_map = model(x)
            loss = loss_function(logit_map, y)
        scaler.scale(loss).backward()
        epoch_loss += loss.item()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scale = scaler.get_scale()
        scaler.update()
        # if not (scale > scaler.get_scale()):
        scheduler.step()
        optimizer.zero_grad()
        epoch_iterator.set_description(  # noqa: B038
            "Training (%d / %d Steps) (loss=%2.5f)" % (global_step, max_iterations, loss)
        )
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optim_state_dict': optimizer.state_dict(),
                    'loss': loss,
                }, os.path.join(root_dir, "best_metric_model.ckpt"))
                print(
                    "Model Was Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(dice_val_best, dice_val)
                )
            else:
                print(
                    "Model Was Not Saved ! Current Best Avg. Dice: {} Current Avg. Dice: {}".format(
                        dice_val_best, dice_val
                    )
                )
        global_step += 1
    print("Current Learning Rate: ", scheduler.get_last_lr()[0])
    print("Epoch Loss: ", epoch_loss)
    return global_step, dice_val_best, global_step_best
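
For context, train() is driven by an outer loop along these lines (the exact driver cell is not shown above; the lists and counters are initialized before it):

# Sketch of the outer driver loop that calls train() until max_iterations is reached.
epoch_loss_values = []
metric_values = []
global_step = 0
dice_val_best = 0.0
global_step_best = 0
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(
        global_step, train_loader, dice_val_best, global_step_best
    )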

Finally, here is an example of the output of training. Even after 8+ hours, the model stays exactly the same:


Any help would be so so appreciated. I wanted to post more images of the data and augmentation for extra clarity but it wouldn’t let me!

What is your batch size? Keeping a very low batch size is bad; it will not let your model learn effectively.
Try to keep a minimum batch size of 8.

If you could share your model architecture, it would be easier to find the issue.

Hi Aniruth,
I edited my post to add how I prepared my data using the CacheDataset and DataLoader.
My training batch size is 2. I experimented quite a bit with larger batch sizes but 2 was the largest I could get without my GPU giving me memory errors. I have an RTX 4070 with 12 GB of VRAM, but since the UNETR is a 3D model, it uses a ton of memory.
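
The only workaround I can think of for the memory limit is gradient accumulation, i.e. summing gradients over several small batches before each optimizer step so the effective batch is larger. Rough sketch only (not what I am currently running; accum_steps is an illustrative number):

accum_steps = 4  # effective batch is roughly batch_size * num_samples * accum_steps
optimizer.zero_grad()
for i, batch in enumerate(train_loader):
    x, y = batch["image"].to(device), batch["label"].to(device)
    with torch.cuda.amp.autocast():
        loss = loss_function(model(x), y) / accum_steps  # scale so accumulated grads average out
    scaler.scale(loss).backward()
    if (i + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()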

Did you try converting your images into patches and training your model on the patches instead of a single high-resolution image? That helps for high-resolution images.

In that case you can increase your batch size too.
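
Something like this, for example (just a sketch with made-up numbers): a smaller crop in RandCropByPosNegLabeld frees enough memory for more patches per volume and a bigger batch, as long as img_size in UNETR is changed to match.

# Illustrative only: smaller patches leave room for a larger batch on the same GPU.
RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    spatial_size=(64, 64, 64),  # smaller than the current 96^3 crop
    pos=1,
    neg=1,
    num_samples=8,              # more patches sampled per volume
    image_key="image",
    image_threshold=0,
)
# UNETR(..., img_size=(64, 64, 64), ...) would then need to match the new patch size.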