Test loss is constant

I am using a U-Net model for segmentation on a custom dataset.

# import the necessary packages
from . import config
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch
 
class Block(Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        # store the convolution and RELU layers
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)

    def forward(self, x):
        # apply CONV => RELU => CONV block to the inputs and return it
        return self.conv2(self.relu(self.conv1(x)))

class Encoder(Module):
    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # store the encoder blocks and maxpooling layer
        self.encBlocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])
        self.pool = MaxPool2d(2)

    def forward(self, x):
        # initialize an empty list to store the intermediate outputs
        blockOutputs = []
        # loop through the encoder blocks
        for block in self.encBlocks:
            # pass the inputs through the current encoder block, store
            # the outputs, and then apply maxpooling on the output
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)
        # return the list containing the intermediate outputs
        return blockOutputs

class Decoder(Module):
    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        # initialize the number of channels, upsampler blocks, and
        # decoder blocks
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
                for i in range(len(channels) - 1)])
        self.dec_blocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])

    def forward(self, x, encFeatures):
        # loop through the number of channels
        for i in range(len(self.channels) - 1):
            # pass the inputs through the upsampler blocks
            x = self.upconvs[i](x)
            # crop the current features from the encoder blocks,
            # concatenate them with the current upsampled features,
            # and pass the concatenated output through the current
            # decoder block
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)
        # return the final decoder output
        return x

    def crop(self, encFeatures, x):
        # grab the dimensions of the inputs, and crop the encoder
        # features to match the dimensions
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)
        # return the cropped features
        return encFeatures
         
class UNet(Module):
    def __init__(self, encChannels=(3, 16, 32, 64),
            decChannels=(64, 32, 16),
            nbClasses=1, retainDim=True,
            outSize=(config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)):
        super().__init__()
        # initialize the encoder and decoder
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)
        # initialize the regression head and store the class variables
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize

    def forward(self, x):
        # grab the features from the encoder
        encFeatures = self.encoder(x)
        # pass the encoder features through the decoder, making sure
        # their dimensions are suited for concatenation
        decFeatures = self.decoder(encFeatures[::-1][0],
            encFeatures[::-1][1:])
        # pass the decoder features through the regression head to
        # obtain the segmentation mask
        map = self.head(decFeatures)
        # check to see if we are retaining the original output
        # dimensions and if so, then resize the output to match them
        if self.retainDim:
            map = F.interpolate(map, self.outSize)
        # return the segmentation map
        return map
 

The test loss is almost constant and only changes slightly:

[INFO] saving testing image paths...
[INFO] found 4353 examples in the training set...
[INFO] found 769 examples in the test set...
[INFO] training the network...
  0% 0/40 [00:00<?, ?it/s][INFO] EPOCH: 1/40
Train loss: 0.111058, Test loss: 0.0145
  2% 1/40 [14:40<9:32:05, 880.13s/it][INFO] EPOCH: 2/40
Train loss: 0.013129, Test loss: 0.0140
  5% 2/40 [15:20<4:04:26, 385.96s/it][INFO] EPOCH: 3/40
Train loss: 0.013229, Test loss: 0.0140
  8% 3/40 [16:00<2:20:37, 228.05s/it][INFO] EPOCH: 4/40
Train loss: 0.012924, Test loss: 0.0140
 10% 4/40 [16:40<1:32:15, 153.76s/it][INFO] EPOCH: 5/40
Train loss: 0.012981, Test loss: 0.0140
 12% 5/40 [17:20<1:05:44, 112.69s/it][INFO] EPOCH: 6/40
Train loss: 0.013169, Test loss: 0.0140
 15% 6/40 [17:59<49:49, 87.92s/it]   [INFO] EPOCH: 7/40
Train loss: 0.013158, Test loss: 0.0140
 18% 7/40 [18:39<39:41, 72.15s/it][INFO] EPOCH: 8/40
Train loss: 0.013200, Test loss: 0.0139
 20% 8/40 [19:19<32:59, 61.85s/it][INFO] EPOCH: 9/40
Train loss: 0.013065, Test loss: 0.0139
 22% 9/40 [19:59<28:24, 54.98s/it][INFO] EPOCH: 10/40
Train loss: 0.013046, Test loss: 0.0139
 25% 10/40 [20:39<25:10, 50.36s/it][INFO] EPOCH: 11/40
Train loss: 0.012856, Test loss: 0.0139
 28% 11/40 [21:18<22:45, 47.09s/it][INFO] EPOCH: 12/40
Train loss: 0.013120, Test loss: 0.0138
 30% 12/40 [21:58<20:56, 44.89s/it][INFO] EPOCH: 13/40
Train loss: 0.012986, Test loss: 0.0138
 32% 13/40 [22:38<19:30, 43.36s/it][INFO] EPOCH: 14/40
Train loss: 0.012794, Test loss: 0.0138
 35% 14/40 [23:17<18:14, 42.11s/it][INFO] EPOCH: 15/40
Train loss: 0.013119, Test loss: 0.0138
 38% 15/40 [23:57<17:13, 41.35s/it][INFO] EPOCH: 16/40
Train loss: 0.013101, Test loss: 0.0138
 40% 16/40 [24:37<16:19, 40.82s/it][INFO] EPOCH: 17/40
Train loss: 0.012831, Test loss: 0.0138
 42% 17/40 [25:17<15:37, 40.76s/it][INFO] EPOCH: 18/40
Train loss: 0.012973, Test loss: 0.0138
 45% 18/40 [25:57<14:51, 40.54s/it][INFO] EPOCH: 19/40
Train loss: 0.012753, Test loss: 0.0138
 48% 19/40 [26:37<14:06, 40.31s/it][INFO] EPOCH: 20/40
Train loss: 0.012880, Test loss: 0.0138
 50% 20/40 [27:16<13:18, 39.90s/it][INFO] EPOCH: 21/40
Train loss: 0.013107, Test loss: 0.0138

While the test loss seems to be almost static, note that your training loss is also not decreasing and jumps around a value of ~0.013, which is the issue I would target first.
Try to overfit a small dataset (e.g. just 10 samples) and make sure your model is able to learn these samples perfectly. Once this works, try to scale up the use case again.
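
For example, something along these lines (a minimal sketch; trainDS, unet, lossFunc, opt, and config.DEVICE stand in for your own dataset, model, criterion, optimizer, and device):

# overfitting check: the model should be able to drive the loss
# towards ~0 on a handful of samples
from torch.utils.data import DataLoader, Subset

tinyDS = Subset(trainDS, list(range(10)))
tinyLoader = DataLoader(tinyDS, shuffle=True, batch_size=2)
unet.train()
for epoch in range(200):
    for (x, y) in tinyLoader:
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
        loss = lossFunc(unet(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    # the printed loss should approach ~0 within a few hundred epochs
    print(epoch, loss.item())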

With about 21 images, the losses are inf:

[INFO] saving testing image paths...
[INFO] found 17 examples in the training set...
[INFO] found 4 examples in the test set...
[INFO] training the network...
  0% 0/40 [00:00<?, ?it/s][INFO] EPOCH: 1/40
Train loss: inf, Test loss: inf
  2% 1/40 [00:23<15:22, 23.65s/it][INFO] EPOCH: 2/40
Train loss: inf, Test loss: inf
  5% 2/40 [00:24<06:19, 10.00s/it][INFO] EPOCH: 3/40
Train loss: inf, Test loss: inf
  8% 3/40 [00:24<03:28,  5.63s/it][INFO] EPOCH: 4/40
Train loss: inf, Test loss: inf
 10% 4/40 [00:24<02:09,  3.59s/it][INFO] EPOCH: 5/40
Train loss: inf, Test loss: inf
 12% 5/40 [00:25<01:25,  2.46s/it][INFO] EPOCH: 6/40
Train loss: inf, Test loss: inf
 15% 6/40 [00:25<01:00,  1.78s/it][INFO] EPOCH: 7/40
Train loss: inf, Test loss: inf
 18% 7/40 [00:26<00:44,  1.35s/it][INFO] EPOCH: 8/40
Train loss: inf, Test loss: inf
 20% 8/40 [00:26<00:34,  1.07s/it][INFO] EPOCH: 9/40
Train loss: inf, Test loss: inf
 22% 9/40 [00:27<00:27,  1.13it/s][INFO] EPOCH: 10/40
Train loss: inf, Test loss: inf
 25% 10/40 [00:27<00:22,  1.32it/s][INFO] EPOCH: 11/40
Train loss: inf, Test loss: inf
 28% 11/40 [00:28<00:19,  1.48it/s][INFO] EPOCH: 12/40
Train loss: inf, Test loss: inf
 30% 12/40 [00:28<00:17,  1.63it/s][INFO] EPOCH: 13/40
Train loss: inf, Test loss: inf
 32% 13/40 [00:29<00:15,  1.75it/s][INFO] EPOCH: 14/40
Train loss: inf, Test loss: inf
 35% 14/40 [00:29<00:13,  1.86it/s][INFO] EPOCH: 15/40
Train loss: inf, Test loss: inf
 38% 15/40 [00:30<00:13,  1.91it/s][INFO] EPOCH: 16/40
Train loss: inf, Test loss: inf
 40% 16/40 [00:30<00:12,  1.97it/s][INFO] EPOCH: 17/40
Train loss: inf, Test loss: inf
 42% 17/40 [00:31<00:11,  2.00it/s][INFO] EPOCH: 18/40
Train loss: inf, Test loss: inf
 45% 18/40 [00:31<00:10,  2.04it/s][INFO] EPOCH: 19/40
Train loss: inf, Test loss: inf
 48% 19/40 [00:32<00:10,  2.04it/s][INFO] EPOCH: 20/40
Train loss: inf, Test loss: inf
 50% 20/40 [00:32<00:09,  2.06it/s][INFO] EPOCH: 21/40
Train loss: inf, Test loss: inf
 52% 21/40 [00:33<00:09,  2.04it/s][INFO] EPOCH: 22/40
Train loss: inf, Test loss: inf
 55% 22/40 [00:33<00:08,  2.03it/s][INFO] EPOCH: 23/40
Train loss: inf, Test loss: inf
 57% 23/40 [00:34<00:08,  2.04it/s][INFO] EPOCH: 24/40
Train loss: inf, Test loss: inf
 60% 24/40 [00:34<00:07,  2.05it/s][INFO] EPOCH: 25/40
Train loss: inf, Test loss: inf
 62% 25/40 [00:35<00:07,  2.05it/s][INFO] EPOCH: 26/40
Train loss: inf, Test loss: inf
 65% 26/40 [00:35<00:06,  2.07it/s][INFO] EPOCH: 27/40
Train loss: inf, Test loss: inf
 68% 27/40 [00:35<00:06,  2.07it/s][INFO] EPOCH: 28/40
Train loss: inf, Test loss: inf
 70% 28/40 [00:36<00:05,  2.08it/s][INFO] EPOCH: 29/40
Train loss: inf, Test loss: inf
 72% 29/40 [00:36<00:05,  2.10it/s][INFO] EPOCH: 30/40
Train loss: inf, Test loss: inf
 75% 30/40 [00:37<00:04,  2.10it/s][INFO] EPOCH: 31/40
Train loss: inf, Test loss: inf
 78% 31/40 [00:37<00:04,  2.11it/s][INFO] EPOCH: 32/40
Train loss: inf, Test loss: inf
 80% 32/40 [00:38<00:03,  2.10it/s][INFO] EPOCH: 33/40
Train loss: inf, Test loss: inf
 82% 33/40 [00:38<00:03,  2.11it/s][INFO] EPOCH: 34/40
Train loss: inf, Test loss: inf
 85% 34/40 [00:39<00:02,  2.09it/s][INFO] EPOCH: 35/40
Train loss: inf, Test loss: inf
 88% 35/40 [00:39<00:02,  2.09it/s][INFO] EPOCH: 36/40
Train loss: inf, Test loss: inf
 90% 36/40 [00:40<00:01,  2.10it/s][INFO] EPOCH: 37/40
Train loss: inf, Test loss: inf
 92% 37/40 [00:40<00:01,  2.09it/s][INFO] EPOCH: 38/40
Train loss: inf, Test loss: inf
 95% 38/40 [00:41<00:00,  2.08it/s][INFO] EPOCH: 39/40
Train loss: inf, Test loss: inf
 98% 39/40 [00:41<00:00,  2.09it/s][INFO] EPOCH: 40/40
Train loss: inf, Test loss: inf
100% 40/40 [00:42<00:00,  1.05s/it]
[INFO] total time taken to train the model: 42.18s

Might this be a problem with backprop? Here is my training script:

from pyimagesearch.dataset import SegmentationDataset
from pyimagesearch.model import UNet
from pyimagesearch import config
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import time
import os

# load the image and mask filepaths in a sorted manner
imagePaths = sorted(list(paths.list_images(config.IMAGE_DATASET_PATH)))
maskPaths = sorted(list(paths.list_images(config.MASK_DATASET_PATH)))
# partition the data into training and testing splits using 85% of
# the data for training and the remaining 15% for testing
split = train_test_split(imagePaths, maskPaths,
    test_size=config.TEST_SPLIT, random_state=42)
# unpack the data split
(trainImages, testImages) = split[:2]
(trainMasks, testMasks) = split[2:]
# write the testing image paths to disk so that we can use them
# when evaluating/testing our model
print("[INFO] saving testing image paths...")
with open(config.TEST_PATHS, "w") as f:
    f.write("\n".join(testImages))

# define transformations (named so the variable does not shadow the
# torchvision.transforms module)
imageTransforms = transforms.Compose([transforms.ToPILImage(),
    transforms.Resize((config.INPUT_IMAGE_HEIGHT,
        config.INPUT_IMAGE_WIDTH)),
    transforms.ToTensor()])
# create the train and test datasets
trainDS = SegmentationDataset(imagePaths=trainImages, maskPaths=trainMasks,
    transforms=imageTransforms)
testDS = SegmentationDataset(imagePaths=testImages, maskPaths=testMasks,
    transforms=imageTransforms)
print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(testDS)} examples in the test set...")
# create the training and test data loaders
trainLoader = DataLoader(trainDS, shuffle=True,
    batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
    num_workers=os.cpu_count())
testLoader = DataLoader(testDS, shuffle=False,
    batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
    num_workers=os.cpu_count())
    
# initialize our UNet model
unet = UNet().to(config.DEVICE)
# initialize loss function and optimizer
lossFunc = BCEWithLogitsLoss()
opt = Adam(unet.parameters(), lr=config.INIT_LR)
# calculate steps per epoch for training and test set
trainSteps = len(trainDS) // config.BATCH_SIZE
testSteps = len(testDS) // config.BATCH_SIZE
# initialize a dictionary to store training history
H = {"train_loss": [], "test_loss": []}

# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.NUM_EPOCHS)):
    # set the model in training mode
    unet.train()
    # initialize the total training and validation loss
    totalTrainLoss = 0
    totalTestLoss = 0
    # loop over the training set
    for (i, (x, y)) in enumerate(trainLoader):
        # send the input to the device
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
        # perform a forward pass and calculate the training loss
        pred = unet(x)
        loss = lossFunc(pred, y)
        # first, zero out any previously accumulated gradients, then
        # perform backpropagation, and then update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()
        # add the loss to the total training loss so far (detach so we
        # do not keep the autograd graph alive across iterations)
        totalTrainLoss += loss.detach()
    # switch off autograd
    with torch.no_grad():
        # set the model in evaluation mode
        unet.eval()
        # loop over the validation set
        for (x, y) in testLoader:
            # send the input to the device
            (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
            # make the predictions and calculate the validation loss
            pred = unet(x)
            totalTestLoss += lossFunc(pred, y)
    # calculate the average training and validation loss
    avgTrainLoss = totalTrainLoss / trainSteps
    avgTestLoss = totalTestLoss / testSteps
    # update our training history
    H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
    H["test_loss"].append(avgTestLoss.cpu().detach().numpy())
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, config.NUM_EPOCHS))
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(
        avgTrainLoss, avgTestLoss))
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))
    
# plot the training loss
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["test_loss"], label="test_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(config.PLOT_PATH)
# serialize the model to disk
torch.save(unet, config.MODEL_PATH)
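
One way to check whether backprop is the problem is to inspect the gradients right after loss.backward() (a minimal sketch reusing the objects defined above):

# backprop sanity check: run a single batch and look at the gradients
x, y = next(iter(trainLoader))
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
loss = lossFunc(unet(x), y)
opt.zero_grad()
loss.backward()
for name, param in unet.named_parameters():
    if param.grad is None:
        # a missing gradient would point to a detached graph
        print(name, "has no gradient")
    else:
        print(name, param.grad.norm().item())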

Since the tiny dataset apparently blows up the training, you should check what exactly the model predicts and whether the loss calculation is valid.
I assume your last layer returns raw logits, or did you apply an activation to the outputs?
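
A quick way to inspect this (a minimal sketch reusing the objects from your training script; BCEWithLogitsLoss expects raw logits as predictions and targets in [0, 1]):

# inspect one batch: are the predictions finite and the targets valid?
x, y = next(iter(trainLoader))
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
with torch.no_grad():
    pred = unet(x)
print(pred.min().item(), pred.max().item())  # raw logits; should be finite
print(torch.isfinite(pred).all().item())     # False would explain the inf
print(y.min().item(), y.max().item())        # targets should lie in [0, 1]
print(lossFunc(pred, y).item())              # the per-batch loss itself

Also double-check trainSteps = len(trainDS) // config.BATCH_SIZE in your script: with only 17 training samples this is 0 whenever BATCH_SIZE is larger than 17, and dividing the accumulated loss tensor by zero yields exactly the inf you are seeing, even if every per-batch loss is finite.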