3D U-Net implementation doesn't overfit

This is the training code I'm using. I had to write my own Resize function because I'm dealing with TIFF images; it uses F.interpolate.
I've been testing the network on a single image and I can't get it to overfit. I also tried several other models from the web and all of them behave the same way, which is why I think the problem is in my training, data-loading, or transformation code.
What do you think could be the problem?

import torch
import unet3D
import numpy as np
import transformations_3_D
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import visdom
import time
from sklearn.metrics import jaccard_score
from torch.functional import F
from dataset_3D import MacularHole
from torch.utils import data
# from images3D_visualizer import ImageVisualizer
from matplotlib import pyplot as plt
import unet3d_github

def labels_resize(tensor, size, normalize=False):
    """Resize a label volume with nearest-neighbour interpolation.

    Nearest mode only ever copies existing voxel values, so resizing never
    invents intermediate values between label classes.

    Args:
        tensor: Tensor passed straight to ``F.interpolate``; with a 3-element
            ``size`` it must be 5-D (N, C, D, H, W).
        size: Target spatial size, e.g. ``(D, H, W)``.
        normalize: If True, min-max rescale the result into [0, 1] (e.g. so it
            can serve as a ``BCELoss`` target).

    Returns:
        The resized, and optionally rescaled, tensor.
    """
    tensor = F.interpolate(tensor, size=size, mode="nearest")
    if normalize:
        lo, hi = torch.min(tensor), torch.max(tensor)
        if hi == lo:
            # Constant volume: (x - min) / (max - min) would divide by zero and
            # produce NaNs. All-zero stays all-zero; a constant non-zero volume
            # becomes all-ones, consistent with treating non-zero as foreground.
            tensor = (tensor != 0).to(tensor.dtype)
        else:
            tensor = (tensor - lo) / (hi - lo)

    return tensor


def prediction_resize(tensor, size):
    """Resample a prediction volume to ``size`` with trilinear interpolation.

    ``align_corners=True`` aligns the corner voxels of input and output, so
    corner values are carried over unchanged.

    Args:
        tensor: 5-D tensor (N, C, D, H, W); trilinear mode requires 5-D input.
        size: Target spatial size ``(D, H, W)``.

    Returns:
        The resampled tensor.
    """
    resampled = F.interpolate(
        tensor,
        size=size,
        mode="trilinear",
        align_corners=True,
    )
    return resampled


def main():
    """Train a 3D U-Net on the MacularHole training split and print per-epoch metrics.

    Each epoch makes one pass over the training loader and optimises
    ``BCELoss`` against labels resized to 28x44x44 and min-max normalized. It
    then prints three values:

    - the mean training loss;
    - the mean Jaccard score, computed at 49x200x200 after thresholding the
      upsampled predictions at 0.5;
    - the maximum prediction value of the last batch.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([transformations_3_D.Resize((116, 132, 132))])

    seed = np.random.randint(0, 1000)

    train_set = MacularHole("dataset_first_tests_45_imgs/", mode='train', seed=seed, transform=transform,
                            training_size=1, valid_size=9, testing_size=9)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True, num_workers=16)

    # The original left the model instantiation commented out, so
    # `model.to(device)` raised UnboundLocalError. Swap in the architecture
    # under test here (e.g. one from `unet3d_github`) if needed.
    model = unet3D.UNet()
    model = model.to(device)

    # NOTE(review): BCELoss expects probabilities in [0, 1], i.e. this assumes
    # the model ends in a sigmoid. If it returns raw logits, use
    # nn.BCEWithLogitsLoss and apply torch.sigmoid before thresholding.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # Per-epoch history (kept for plotting)
    train_performance_values = []
    train_loss_values = []

    epochs = 5000
    train_max_value = 0.0
    for i in range(epochs):
        train_loss = 0.0
        jaccard_scores_train = []

        model.train()
        for images, labels in train_loader:
            images, labels = images.float(), labels.float()
            images, labels = images.to(device), labels.to(device)
            labels = labels_resize(labels, (49, 200, 200))

            optimizer.zero_grad()
            train_predictions = model(images)
            # NOTE(review): a 116x132x132 input producing a 28x44x44 output
            # suggests an unpadded U-Net whose output covers only the central
            # part of the input. Downsampling the *whole* label volume to
            # 28x44x44 then pairs output voxels with the wrong targets, which
            # can prevent overfitting. Verify the output's field of view; if it
            # is a center region, center-crop the labels instead of resizing.
            resize_labels = labels_resize(labels, (28, 44, 44), normalize=True)
            # .item() stores a plain float instead of a tensor tied to the autograd graph
            train_max_value = torch.max(train_predictions).item()
            loss = criterion(train_predictions, resize_labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            # Jaccard score at label resolution
            with torch.no_grad():
                upsampled_predictions = prediction_resize(train_predictions, (49, 200, 200))
                # Any non-zero label is foreground; predictions are thresholded at 0.5
                binary_labels = (labels != 0).cpu().numpy().astype(int)
                binary_predictions = (upsampled_predictions > 0.5).cpu().numpy().astype(int)
                jaccard_scores_train.append(
                    jaccard_score(binary_labels.flatten(), binary_predictions.flatten())
                )

        # loss.item() is already a per-batch mean, so average over batches, not samples
        train_loss = train_loss / len(train_loader)
        train_loss_values.append(train_loss)
        train_performance = np.mean(jaccard_scores_train)
        train_performance_values.append(train_performance)

        print(f'Epoch: {i + 1} |' f'\tTrain loss: {train_loss} |' f'\t| Train performance: {train_performance} |' 
              f'\t| Max training: {train_max_value} |')



if __name__ == "__main__":
    # Script entry point. Visdom plotting is currently disabled; re-enable it
    # by creating `visdom.Visdom()` here and passing it where needed.
    main()