PyTorch model not learning at all

I’m trying to build a 2D U-Net model that takes 3D inputs (volumes that are cut into 2D slices),
but the model does not seem to be learning at all.
Here are my train.py and diagnostic_data_loading.py.
Judging from what is printed to the terminal (loss, Dice coefficient), I would expect it to output something, but I’m getting blank masks.

import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import nibabel as nib
import wandb
from evaluate import evaluate
from unet import UNet
from utils.diagnostic_data_loading import BasicDataset
from utils.dice_score import dice_loss
from sklearn.model_selection import train_test_split

# Subject-wise split function
def split_data_subjectwise(t1_files, mask_files, test_size=0.2, random_state=42):
    """Split paired T1/mask files into train and validation sets by subject.

    The subject ID of a file is the first three characters of its file name.
    Only subjects that have both a T1 file and a mask file take part in the
    split, and a subject's T1 and mask always land on the same side.

    Args:
        t1_files: Paths (``str`` or ``Path``) of the T1 volumes.
        mask_files: Paths (``str`` or ``Path``) of the mask volumes.
        test_size: Forwarded to ``train_test_split``.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        A tuple ``(train_t1, train_mask, val_t1, val_mask)`` of four lists
        holding the original path objects, index-aligned per subject.
    """
    def subject_id(path):
        # Path(...).name instead of split('/') so that Windows separators and
        # Path objects are handled as well as POSIX strings.
        return Path(path).name[:3]

    t1_files = list(t1_files)
    mask_files = list(mask_files)
    t1_files_dict = {subject_id(f): f for f in t1_files}
    mask_files_dict = {subject_id(f): f for f in mask_files}

    # Two files sharing a 3-character prefix overwrite each other in the
    # dicts above; make that visible instead of silently losing data.
    if len(t1_files_dict) != len(t1_files) or len(mask_files_dict) != len(mask_files):
        logging.warning(
            "Several files share the same 3-character subject ID; "
            "only one file per subject is kept."
        )

    # Sorted: iteration order of a set of strings changes between interpreter
    # runs (hash randomization), which would make the seeded split below
    # non-reproducible.
    matched_subject_ids = sorted(t1_files_dict.keys() & mask_files_dict.keys())

    unmatched = (t1_files_dict.keys() | mask_files_dict.keys()) - set(matched_subject_ids)
    if unmatched:
        logging.warning(
            "%d subject(s) lack either a T1 or a mask file and are ignored: %s",
            len(unmatched), sorted(unmatched),
        )

    train_subjects, val_subjects = train_test_split(
        matched_subject_ids, test_size=test_size, random_state=random_state
    )
    train_t1 = [t1_files_dict[sid] for sid in train_subjects]
    val_t1 = [t1_files_dict[sid] for sid in val_subjects]
    train_mask = [mask_files_dict[sid] for sid in train_subjects]
    val_mask = [mask_files_dict[sid] for sid in val_subjects]

    print(f"Train T1 files: {len(train_t1)}, Train Mask files: {len(train_mask)}")
    print(f"Val T1 files: {len(val_t1)}, Val Mask files: {len(val_mask)}")

    return train_t1, train_mask, val_t1, val_mask

def train_model(
        model,
        device,
        epochs=1,
        batch_size=1,
        learning_rate=1e-5,
        val_percent=0.1,
        save_checkpoint=True,
        img_scale=0.5,
        amp=False,
        weight_decay=1e-8,
        gradient_clipping=1.0,
        train_loader=None,
        val_loader=None,
        dir_checkpoint=Path('./checkpoints')
):
    """Train ``model`` on ``train_loader`` and log the run to Weights & Biases.

    Args:
        model: Network to train; must expose ``n_classes``. One class selects
            ``BCEWithLogitsLoss``, otherwise ``CrossEntropyLoss`` is used.
        device: Device the batches are moved to.
        epochs: Number of passes over ``train_loader``.
        batch_size: Only recorded in the wandb config; the loaders are built
            by the caller.
        learning_rate: Adam learning rate.
        val_percent: Only recorded in the wandb config.
        save_checkpoint: When true, the state dict is saved after every epoch.
        img_scale: Only recorded in the wandb config.
        amp: Enables autocast and gradient scaling.
        weight_decay: Adam weight decay.
        gradient_clipping: Max gradient norm; a falsy value disables clipping.
        train_loader: Iterable of ``(image, mask)`` batches.
        val_loader: Must not be ``None``.
            NOTE(review): it is not consumed here -- no validation pass is run.
        dir_checkpoint: Directory that receives the checkpoints.

    Raises:
        ValueError: If a loader is missing or ``train_loader`` is empty.
        RuntimeError: If ``wandb.init`` returns ``None``.
    """
    # Validate the inputs first so that no wandb run is opened for a call
    # that cannot train anyway.
    if train_loader is None or val_loader is None:
        raise ValueError("train_loader or val_loader is None. Ensure DataLoaders are initialized correctly.")
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; there is nothing to train on.")

    # Initialize wandb if it hasn't been initialized
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    if experiment is None:
        raise RuntimeError("Failed to initialize Weights and Biases (wandb). Ensure wandb.init() is called properly.")

    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )
    print("Weights and Biases (wandb) initialized successfully.")

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    binary = model.n_classes == 1
    criterion = nn.BCEWithLogitsLoss() if binary else nn.CrossEntropyLoss()
    scaler = torch.amp.GradScaler(enabled=amp)
    mask_dtype = torch.float32 if binary else torch.long

    # Autocast on the device that is actually in use instead of a hard-coded
    # 'cuda'; 'mps' falls back to 'cpu' for autocast.
    device_type = torch.device(device).type
    if device_type == 'mps':
        device_type = 'cpu'

    try:
        for epoch in range(epochs):
            model.train()
            epoch_loss = 0
            with tqdm(total=num_batches, desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
                for batch_idx, batch in enumerate(train_loader):
                    images = batch[0].to(device, dtype=torch.float32)  # 2D T1 slices
                    true_masks = batch[1].to(device, dtype=mask_dtype)  # 2D masks

                    if batch_idx == 0:
                        print(f"[Debug] Input batch shape: {images.shape}")  # Ensure images are 2D
                        print(f"[Debug] True mask shape: {true_masks.shape}")  # Ensure masks are 2D

                    with torch.amp.autocast(device_type=device_type, enabled=amp):
                        masks_pred = model(images)  # Feed 2D data into the model
                        target = true_masks.squeeze(1)
                        if binary:
                            # BCEWithLogitsLoss needs logits and target of the
                            # same shape, so the channel axis is dropped from
                            # the prediction as well.
                            loss = criterion(masks_pred.squeeze(1), target)
                        else:
                            # CrossEntropyLoss takes class-index targets
                            # without a channel axis.
                            loss = criterion(masks_pred, target)

                    if batch_idx == 0:
                        print(f"[Debug] Predicted mask shape: {masks_pred.shape}")
                        print(f"[Debug] Predicted mask (sample values): {masks_pred[0, 0, :5, :5].detach().cpu().numpy()}")

                    optimizer.zero_grad(set_to_none=True)
                    scaler.scale(loss).backward()
                    if gradient_clipping:
                        # Gradients are still multiplied by the loss scale
                        # here; unscale them so the norm threshold applies to
                        # the true gradients.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                    scaler.step(optimizer)
                    scaler.update()

                    batch_loss = loss.item()
                    epoch_loss += batch_loss
                    pbar.update()
                    pbar.set_postfix(**{'loss (batch)': batch_loss})

            avg_loss = epoch_loss / num_batches
            # This value is the mean training loss of the epoch, so it is
            # logged under a name that says so (it is not a Dice score).
            wandb.log({"Train Loss": avg_loss, "epoch": epoch + 1})

            if save_checkpoint:
                dir_checkpoint.mkdir(parents=True, exist_ok=True)
                checkpoint_path = dir_checkpoint / f'checkpoint_epoch{epoch + 1}.pth'
                torch.save(model.state_dict(), checkpoint_path)
                wandb.save(str(checkpoint_path))
    finally:
        # Close the run on failure too, so a retry starts from a clean state.
        if wandb.run:
            wandb.finish()
    print("Training completed and Weights and Biases logging finished.")

def get_args():
    """Parse the training script's command-line options.

    ``--img-dir``, ``--mask-dir`` and ``--checkpoint-dir`` are mandatory; every
    other option has a default.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    add = cli.add_argument

    # Optimisation settings
    add('--epochs', '-e', type=int, default=5, metavar='E', help='Number of epochs')
    add('--batch-size', '-b', type=int, default=1, metavar='B', dest='batch_size', help='Batch size')
    add('--learning-rate', '-l', type=float, default=1e-5, metavar='LR', dest='lr', help='Learning rate')
    add('--load', '-f', type=str, default=False, help='Load model from a .pth file')

    # Input and output locations
    add('--img-dir', required=True, type=str, help='Path to the image directory')
    add('--mask-dir', required=True, type=str, help='Path to the mask directory')
    add('--checkpoint-dir', required=True, type=str, help='Directory to save checkpoints')

    # Data handling
    add('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    add(
        '--validation', '-v',
        type=float,
        default=10.0,
        dest='val',
        help='Percent of the data that is used as validation (0-100)',
    )

    # Model and precision switches
    add('--amp', action='store_true', default=False, help='Use mixed precision')
    add('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    add('--classes', '-c', type=int, default=2, help='Number of classes')

    return cli.parse_args()

if __name__ == '__main__':
    # Script entry point: parse the CLI, build the model and the data
    # loaders, then train. On a CUDA out-of-memory error training is retried
    # once with mixed precision and half the batch size.
    args = get_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = UNet(n_channels=1, n_classes=args.classes, bilinear=args.bilinear).to(device)
    if args.load:
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all a state dict needs.
        model.load_state_dict(torch.load(args.load, map_location=device, weights_only=True))
        print(f"Model loaded from {args.load}")

    # Load the T1 and mask files from directories. Sorted, because glob order
    # is filesystem dependent and the seeded split should be reproducible.
    t1_files = sorted(Path(args.img_dir).glob("*_T1_re.nii.gz"))
    mask_files = sorted(Path(args.mask_dir).glob("*_M_re.nii.gz"))
    if not t1_files or not mask_files:
        raise SystemExit(
            f"No input volumes found: {len(t1_files)} T1 file(s) in {args.img_dir}, "
            f"{len(mask_files)} mask file(s) in {args.mask_dir}."
        )

    # Perform subject-wise split
    train_t1_files, train_mask_files, val_t1_files, val_mask_files = split_data_subjectwise(
        [str(f) for f in t1_files],
        [str(f) for f in mask_files]
    )

    # Create datasets and loaders
    train_dataset = BasicDataset(images_dir=train_t1_files, masks_dir=train_mask_files)
    val_dataset = BasicDataset(images_dir=val_t1_files, masks_dir=val_mask_files)

    def build_loaders(batch_size):
        """Return ``(train_loader, val_loader)`` for the given batch size."""
        train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
        val = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
        return train, val

    def run_training():
        """Train with the current values of ``args``."""
        train_loader, val_loader = build_loaders(args.batch_size)
        train_model(
            model=model,
            device=device,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            val_percent=args.val / 100,
            save_checkpoint=True,
            img_scale=args.scale,
            amp=args.amp,
            train_loader=train_loader,
            val_loader=val_loader,
            dir_checkpoint=Path(args.checkpoint_dir)
        )

    try:
        run_training()
    except torch.cuda.OutOfMemoryError:
        print("OutOfMemoryError detected! Clearing cache and attempting again with reduced settings.")
        torch.cuda.empty_cache()
        # Close a run the failed attempt may have left open before a new one
        # is started.
        if wandb.run:
            wandb.finish()
        args.amp = True
        args.batch_size = max(1, args.batch_size // 2)
        # Actually retry: the reduced settings are useless unless training is
        # started again.
        run_training()

import logging
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset
import nibabel as nib

def load_image(filepath):
    """Read the .nii.gz file at *filepath* with nibabel.

    Returns:
        The array produced by the loaded image's ``get_fdata()``.
    """
    return nib.load(filepath).get_fdata()

class BasicDataset(Dataset):
    """Map-style dataset of 2D slices cut along axis 0 of paired 3D volumes.

    All volumes are loaded in the constructor and every slice along axis 0 is
    kept in memory as an ``(image_slice, mask_slice)`` pair in
    ``self.slice_pairs``.

    Args:
        images_dir: List (or tuple) of image volume paths.
        masks_dir: List (or tuple) of mask volume paths, index-aligned with
            ``images_dir``.
        transform: Optional callable applied to the image tensor and,
            separately, to the mask tensor.
        normalize: When true (default), each image volume is min-max scaled
            to [0, 1] before slicing. Pass ``False`` to keep raw intensities.

    Raises:
        TypeError: If either path argument is not a list or tuple.
        ValueError: If the two path sequences differ in length.
    """

    def __init__(self, images_dir, masks_dir, transform=None, normalize=True):
        super().__init__()

        # Accept lists (or tuples) of paths directly
        if not (isinstance(images_dir, (list, tuple)) and isinstance(masks_dir, (list, tuple))):
            raise TypeError("images_dir and masks_dir should be lists of file paths.")

        self.image_paths = [Path(p) for p in images_dir]
        self.mask_paths = [Path(p) for p in masks_dir]
        if len(self.image_paths) != len(self.mask_paths):
            raise ValueError("The number of images and masks must be the same.")

        self.transform = transform
        self.normalize = normalize
        self.slice_pairs = []

        # Populate slice pairs with sagittal image-mask slices (axis 0)
        for img_path, mask_path in zip(self.image_paths, self.mask_paths):
            # float32 up front: the slices are served as float32 anyway and
            # the cached volumes take half the memory of float64.
            image_data = np.asarray(load_image(img_path), dtype=np.float32)
            mask_data = np.asarray(load_image(mask_path), dtype=np.float32)

            # Verify that images are 3D
            if image_data.ndim != 3 or mask_data.ndim != 3:
                logging.warning(f"Image or mask at {img_path} is not 3D. Skipping.")
                continue

            # A mask on a different grid would pair slices that do not
            # correspond to each other (or run out of slices).
            if image_data.shape != mask_data.shape:
                logging.warning(
                    "Image %s and mask %s have different shapes (%s vs %s). Skipping.",
                    img_path, mask_path, image_data.shape, mask_data.shape,
                )
                continue

            if normalize:
                image_data = self._normalize_volume(image_data)
            mask_data = self._clean_mask(mask_data)

            # Iterating a 3D array yields its slices along axis 0.
            self.slice_pairs.extend(zip(image_data, mask_data))

    @staticmethod
    def _normalize_volume(volume):
        """Min-max scale *volume* to [0, 1]; a constant volume becomes zeros."""
        volume = np.asarray(volume, dtype=np.float32)
        if volume.size == 0:
            return volume
        lowest = float(volume.min())
        highest = float(volume.max())
        if highest <= lowest:
            return np.zeros_like(volume)
        return (volume - lowest) / (highest - lowest)

    @staticmethod
    def _clean_mask(mask):
        """Round *mask* to the nearest integer value, kept as float32.

        Presumably the masks hold integer labels; resampled masks can carry
        values such as 0.9999, which a later cast to an integer dtype would
        truncate to 0 -- TODO confirm against the data.
        """
        return np.rint(np.asarray(mask, dtype=np.float32)).astype(np.float32)

    def __len__(self):
        """Return the total number of cached slice pairs."""
        return len(self.slice_pairs)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` float32 tensors of shape ``(1, H, W)``."""
        image_slice, mask_slice = self.slice_pairs[idx]

        # Convert slices to torch tensors (copies, so the cache is never
        # modified by a transform) and add the channel dimension.
        image_tensor = torch.tensor(image_slice, dtype=torch.float32).unsqueeze(0)
        mask_tensor = torch.tensor(mask_slice, dtype=torch.float32).unsqueeze(0)

        if self.transform:
            # NOTE(review): the transform is called independently on image
            # and mask; a random transform would desynchronize the pair.
            image_tensor = self.transform(image_tensor)
            mask_tensor = self.transform(mask_tensor)

        return image_tensor, mask_tensor