Loss problem in net fine-tuning

Well, you are apparently using an older version of PyTorch, but I don’t think that’s the problem here.
Still, you should upgrade to the latest stable release, since e.g. Variables and tensors were merged.
You can find the install instructions on the website.

Also, you don’t have to reconstruct the criterion in each iteration.
Move the cren_loss = nn.CrossEntropyLoss() line above the for loop.
This shouldn’t be the problem either.
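Something like this (a minimal sketch; num_epochs stands in for your epoch count):

cren_loss = nn.CrossEntropyLoss()  # construct the criterion once

for epoch in range(num_epochs):
    for batch_idx, (img, gt) in enumerate(train_loader):
        # ...
        loss = cren_loss(output, target_long)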

Is the loss increase happening exactly after one full epoch?

I’m stuck on the older version because of my work group; unfortunately I can’t upgrade to the new version right now.
The loss criterion is outside the for loop in my actual code; I only put it inside the loop in the snippet for readability.

The drastic loss increase and accuracy decrease both happen over exactly one epoch, every time I run the experiment.
My guess is the net only uses the pretrained weights (which are good at their job) during the first epoch, giving me good results. Starting from the second epoch the net starts fresh without any weights, learning from zero.
This would explain why the accuracy drops so much (it’s a segmentation task, and all output images become almost completely black).

OK, something seems to be broken. Could you post the whole code?
As far as I can tell, the current code looks good.

If you cannot post the code due to your work policy, could you have a look at the norm of the gradients in the first and second epoch?

I’ll post the whole code in a few moments, no problem. I’ll comment some parts to make it easier to read.

This is the full code:

import argparse
import logger
import time
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import transforms
from data import MyDataset
from segnet import SegNet
from torch.autograd import Variable

def train(epoch):
    model.train()

    # update learning rate
    exp_lr_scheduler.step()

    total_loss = 0
    total_accuracy = 0

    # iteration over the batches
    for batch_idx, (img, gt) in enumerate(train_loader):

        if use_cuda:
            img = img.cuda(async=True)
            gt = gt.cuda(async=True)
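            # (.cuda(async=True) is the pre-0.4 API; newer releases renamed it non_blocking=True)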

        input = Variable(img)
        target = Variable(gt)

        # initialize gradients
        optimizer.zero_grad()

        # predictions
        output = model(input)

        """
        output is (24, 2, 224, 224)
        target is (24, 1, 224, 224)
        Here I change target.view() and type in order to use nn.CrossEntropyLoss()
        """
        
        tb = target.size(0)
        tc = target.size(1)
        th = target.size(2)
        tw = target.size(3)
        target_long = target.view(tb, th, tw).long()
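        # nn.CrossEntropyLoss expects class-index targets of shape (N, H, W) with dtype long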

        loss = cren_loss(output.cuda(), target_long.cuda())
        loss.backward()
        optimizer.step()

        """
        This is a segmentation task, so in the next part I compute how many 1 pixels are correctly classificated
        as 1 and how many 0 pixels are correctly 0. Then I simply calculate the mean of foreground and background
        accuracy.
        """
        
        output_pred = softmax(output)
        _, prediction = output_pred.max(dim=1)
        prediction = prediction.unsqueeze(1)

        mat_zero2zero = ((prediction == 0) * (target == 0)).int()
        mat_one2one = ((prediction == 1) * (target == 1)).int()

        prediction_back = mat_zero2zero.sum().float()
        target_back = target.numel() - target.sum()

        prediction_fore = mat_one2one.sum().float()
        target_fore = target.sum()

        acc_back = prediction_back / target_back
        acc_fore = prediction_fore / target_fore
        accuracy = (acc_back + acc_fore) / 2

        # TensorBoard logging
        info = {'train-loss': loss.data[0],
                'train-accuracy': accuracy}

        for tag, value in info.items():
            log.scalar_summary(tag, value, batch_idx + 1)

        print('batch: %5s | loss: %.3f | acc_back: %.3f | acc_fore: %.3f | acc: %.3f |'
              % (str(batch_idx + 1) + '/' + str(len(train_loader)),
                 loss.data[0],
                 acc_back,
                 acc_fore,
                 accuracy),
              time.strftime("%H:%M:%S", time.gmtime(time.time())),
              'training')

        total_loss += loss.data[0]
        total_accuracy += accuracy

    return total_loss / len(train_loader), total_accuracy / len(train_loader)


# training settings
parser = argparse.ArgumentParser(description='PyTorch SegNet')
parser.add_argument('--epochs', type=int, default=10, help='train epochs') 
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--momentum', type=float, default=0.5, help='SGD momentum')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

# cuda
use_cuda = torch.cuda.is_available()

input_nbr = 3
label_nbr = 2
img_size = 224

batch_size = 24
num_workers = 4

start_epoch = 0

softmax = torch.nn.Softmax(dim=1)

if use_cuda:
    cren_loss = nn.CrossEntropyLoss().cuda()
else:
    cren_loss = nn.CrossEntropyLoss()

# create SegNet model
model = SegNet(input_nbr, label_nbr)
model.load_from_filename('/path/to/pretrained/weights')

# convert to cuda if needed
if use_cuda:
    model.cuda()
    cudnn.benchmark = True
else:
    model.float()

# finetuning
ftparams = ['conv11d.weight', 'conv11d.bias']
for name, param in model.named_parameters():
    if name not in ftparams:
        param.requires_grad = False

# define the optimizer
optimizer = optim.SGD(model.conv11d.parameters(), lr=args.lr, momentum=args.momentum)
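# note: passing model.conv11d.parameters() works because conv11d is the only trainable
# layer; with a mixed set of frozen/trainable layers, the usual idiom would be
# optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), ...)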
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# define data
root_dir_img = '/path/to/img/dir'
root_dir_gt = '/path/to/gt/dir'

transform_train = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomResizedCrop(img_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor()
])

train_dataset = MyDataset(root_dir_img, root_dir_gt, transform_train)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True
)

# Set the logger
log = logger.Logger('./logs')

for epoch in range(start_epoch, start_epoch + args.epochs):
    print('epoch: %5s' % str(epoch+1))

    # training
    train_loss, train_acc = train(epoch)
    print('\nepoch: %5s | loss: %.3f | acc: %.3f |'
          % (str(epoch + 1) + '/' + str(start_epoch + args.epochs),
             train_loss,
             train_acc),
          time.strftime("%H:%M:%S", time.gmtime(time.time())),
          'training')

    print('\n')

Thanks for the code. I am currently working on it, creating some dummy data and targets.
One thing I’ve seen so far is the usage of the transformations.
Since you are working on a segmentation task, I assume you have segmentation maps as the target.
I cannot see how your Dataset is implemented, but if you are using random transformations like RandomResizedCrop and flipping, you have to make sure they are also applied to your target.
Otherwise your input will be transformed while the target stays fixed, and the model might have a hard time learning the relationship between input and target.

The easiest way would be to use the functional API of torchvision.
Here is a small example I created a while ago.
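In spirit, something like this (a rough sketch using torchvision’s functional API; the crop/flip parameters here are illustrative):

import random
import torchvision.transforms.functional as TF
from torchvision import transforms
from PIL import Image

def paired_transform(image, mask, size=224):
    # sample the crop parameters once and apply them to image AND mask
    i, j, h, w = transforms.RandomResizedCrop.get_params(
        image, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.))
    image = TF.resized_crop(image, i, j, h, w, (size, size))
    # nearest neighbor for the mask, so no new label values are interpolated in
    mask = TF.resized_crop(mask, i, j, h, w, (size, size),
                           interpolation=Image.NEAREST)

    # share the same coin flip between image and mask
    if random.random() < 0.5:
        image, mask = TF.hflip(image), TF.hflip(mask)

    return TF.to_tensor(image), TF.to_tensor(mask)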

Let me know if this helps!

The transformations are already applied to both images and ground truths where needed.

The dataset consists of some objects and their binary segmentation maps.
I could provide the code I’m using for the dataset creation / transforms / net implementation if that could help.

Anyway, everything seems to work fine during the first epoch: accuracy is high and loss is low, since the pretrained weights are good. The problem is the transition from the first epoch to the second; my guess is some parameters are not handled correctly.

How can I check the norm of the gradients you were talking about?

Could you post the transformation part of your Dataset, please?
Are you using transform_train in it?

You can check it with model.conv11d.weight.grad.norm().
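For example, right after the backward pass in your train loop (sketch):

        loss.backward()
        print('conv11d grad norm:', model.conv11d.weight.grad.norm())
        optimizer.step()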

This is my Dataset class.

import os
import torch.utils.data
from PIL import Image
from PIL import ImageFile


class MyDataset(torch.utils.data.Dataset):

    def __init__(self, root_dir_img, root_dir_gt, transform=None):

        self.root_dir_img = root_dir_img
        self.root_dir_gt = root_dir_gt
        self.transform = transform

        img_names = [os.path.join(root_dir_img, name) for name in os.listdir(root_dir_img) if
                     os.path.isfile(os.path.join(root_dir_img, name))]

        gt_names = [os.path.join(root_dir_gt, name) for name in os.listdir(root_dir_gt) if
                    os.path.isfile(os.path.join(root_dir_gt, name))]

        self.img_files = []
        self.gt_files = []

        for i in range(len(img_names)):
            self.img_files.append(Image.open(img_names[i]))
            self.gt_files.append(Image.open(gt_names[i]))

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):

        ImageFile.LOAD_TRUNCATED_IMAGES = True

        img = self.img_files[idx]
        gt = self.gt_files[idx]

        sample = {'image': img, 'mask': gt}

        if self.transform:
            sample = self.transform(sample)
            img = sample['image']
            gt = sample['mask']

        return img, gt

I will check grad.norm() now.

epoch:  1/10 | loss: 0.499 | acc: 0.877 | 14:18:40 training
Variable containing:
 0.5379
[torch.cuda.FloatTensor of size 1 (GPU 0)]
epoch:  2/10 | loss: 4.012 | acc: 0.506 | 14:18:48 training
Variable containing:
 2.0424
[torch.cuda.FloatTensor of size 1 (GPU 0)]
epoch:  3/10 | loss: 4.082 | acc: 0.504 | 14:18:57 training
Variable containing:
 2.2331
[torch.cuda.FloatTensor of size 1 (GPU 0)]

These are the stats and norm of the gradients for the first three epochs.

Could you try to run your code with one or two image-mask pairs and see how your model behaves then?
I still don’t see any obvious errors in your code, so we should have a look at whether the data is somehow corrupted or changed, even though you are not calling anything after the train() call, right?

I’m not calling anything after the train function.
If I run the net for inference it works fine; it does a good job at segmenting using the pretrained weights.
But the model obtained after fine-tuning is unusable (as shown by the accuracy drop from 85% to 50%).
I noticed that if I let the training process run for many epochs (100+) I get a working model, basically trained from scratch. This doesn’t solve my problem, but I guess it’s just another confirmation that the whole thing “works”, and that the parameters “get lost” moving from epoch 1 to epoch 2.
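A quick check I can run (sketch) is to snapshot the frozen parameters and verify whether they really change between epochs:

frozen = {name: p.data.clone() for name, p in model.named_parameters()
          if not p.requires_grad}
train(0)
for name, p in model.named_parameters():
    if name in frozen:
        print(name, 'changed:', not torch.equal(frozen[name], p.data))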

Yeah, I see the issue.
Could you remove the LOAD_TRUNCATED_IMAGES workaround and try it again?
I still have the feeling the error is somehow related to the data.

EDIT: Also, could you remove the cuda() calls from this line:

loss = cren_loss(output.cuda(), target_long.cuda())

You are probably on the right track.

I removed this line:

ImageFile.LOAD_TRUNCATED_IMAGES = True

And I got this error:

Traceback (most recent call last):
 File "/.../train.py", line 192, in <module>
   train_loss, train_acc = train(epoch)
 File "/.../train.py", line 28, in train
   for batch_idx, (img, gt) in enumerate(train_loader):
 File "/.../venv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 281, in __next__
   return self._process_next_batch(batch)
 File "/...e/venv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 301, in _process_next_batch
   raise batch.exc_type(batch.exc_msg)
OSError: Traceback (most recent call last):
 File "/.../venv/lib/python3.6/site-packages/PIL/ImageFile.py", line 215, in load
   s = read(self.decodermaxblock)
 File "/.../venv/lib/python3.6/site-packages/PIL/PngImagePlugin.py", line 619, in load_read
   cid, pos, length = self.png.read()
 File "/.../venv/lib/python3.6/site-packages/PIL/PngImagePlugin.py", line 114, in read
   length = i32(s)
 File "/.../venv/lib/python3.6/site-packages/PIL/_binary.py", line 76, in i32be
   return unpack(">I", c[o:o+4])[0]
struct.error: unpack requires a buffer of 4 bytes

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
 File "/.../venv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 55, in _worker_loop
   samples = collate_fn([dataset[i] for i in batch_indices])
 File "/.../venv/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 55, in <listcomp>
   samples = collate_fn([dataset[i] for i in batch_indices])
 File "/.../data.py", line 41, in __getitem__
   sample = self.transform(sample)
 File "/.../transforms.py", line 584, in __call__
   sample = t(sample)
 File "/.../transforms.py", line 1074, in __call__
   img = transform(img)
 File "/.../transforms.py", line 584, in __call__
   sample = t(sample)
 File "/.../transforms.py", line 794, in __call__
   return self.lambd(img)
 File "/.../transforms.py", line 1048, in <lambda>
   transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))
 File "/.../transforms.py", line 462, in adjust_contrast
   enhancer = ImageEnhance.Contrast(img)
 File "/.../venv/lib/python3.6/site-packages/PIL/ImageEnhance.py", line 66, in __init__
   mean = int(ImageStat.Stat(image.convert("L")).mean[0] + 0.5)
 File "/.../venv/lib/python3.6/site-packages/PIL/Image.py", line 879, in convert
   self.load()
 File "/.../venv/lib/python3.6/site-packages/PIL/ImageFile.py", line 220, in load
   raise IOError("image file is truncated")
OSError: image file is truncated

You could add a print statement with the image path into your Dataset to debug which images are throwing this error.
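For example (a sketch; it assumes the Dataset keeps the file paths as attributes and opens the files lazily in __getitem__):

    def __getitem__(self, idx):
        # print the paths before loading, so the offending file shows up
        # right above the traceback
        print('loading', self.img_names[idx], '/', self.gt_names[idx])
        img = Image.open(self.img_names[idx])
        gt = Image.open(self.gt_names[idx])
        sample = {'image': img, 'mask': gt}
        if self.transform:
            sample = self.transform(sample)
        return sample['image'], sample['mask']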

I’m going to drive home from the office now. Later tonight I will try this and update you with my progress.
Thank you very much for your help so far!


OK, so I’m finally back at it. Last night I tried with some new images but got the same error.
I also tried printing which image is being loaded by the DataLoader when the error is thrown, but it’s never the same image, so I guess the problem is not in the data itself but in my Dataset class.

Is there something wrong in this code?

import os
import torch.utils.data
from PIL import Image
from PIL import ImageFile


class MyDataset(torch.utils.data.Dataset):

    def __init__(self, root_dir_img, root_dir_gt, transform=None):

        self.root_dir_img = root_dir_img
        self.root_dir_gt = root_dir_gt
        self.transform = transform

        img_names = [os.path.join(root_dir_img, name) for name in os.listdir(root_dir_img)]

        gt_names = [os.path.join(root_dir_gt, name) for name in os.listdir(root_dir_gt)]

        self.img_files = []
        self.gt_files = []

        for i in range(len(img_names)):
            self.img_files.append(Image.open(img_names[i]))
            self.gt_files.append(Image.open(gt_names[i]))

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):

        img = self.img_files[idx]
        gt = self.gt_files[idx]

        sample = {'image': img, 'mask': gt}

        if self.transform:
            sample = self.transform(sample)
            img = sample['image']
            gt = sample['mask']

        return img, gt

I’m not sure how many images you have, but could you move the Image.open calls to __getitem__?
Usually you would get a warning if too many files are open, so this shouldn’t be the issue, but we could try it.

Also, I still don’t know what your self.transform function is. It can’t be the transform_train you posted with stock transforms, since you are passing a dict, which wouldn’t work with the standard torchvision transformations.

Could you post the code of transform?

This is my new dataset class:

class MyDataset(Dataset):

    def __init__(self, root_dir_img, root_dir_gt, transform=None):

        self.root_dir_img = root_dir_img
        self.root_dir_gt = root_dir_gt
        self.transform = transform

        self.img_names = [os.path.join(root_dir_img, name) for name in os.listdir(root_dir_img)]
        self.gt_names = [os.path.join(root_dir_gt, name) for name in os.listdir(root_dir_gt)]

        self.img_names.sort()
        self.gt_names.sort()

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):

        img = Image.open(self.img_names[idx])
        gt = Image.open(self.gt_names[idx])

        sample = {'image': img, 'mask': gt}

        if self.transform:
            sample = self.transform(sample)
            # img = sample['image']  # can I remove these lines?
            # gt = sample['mask']

        return img, gt

And this is the code I’m using for the transforms (I’m posting only the parts I modified from the stock torchvision transforms.py; tell me if you need anything more):

class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.

    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
            transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
            transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
            transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = Compose(transforms)

        return transform

    def __call__(self, sample):
        """
        Args:
            img (PIL Image): Input image.

        Returns:
            PIL Image: Color jittered image.
        """
        img, mask = sample['image'], sample['mask']
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        img = transform(img)

        return {'image': img, 'mask': mask}


# ...


class RandomResizedCrop(object):
    """Crop the given PIL Image to random size and aspect ratio.

    A crop of random size of (0.08 to 1.0) of the original size and a random
    aspect ratio of 3/4 to 4/3 of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = (size, size)
        self.interpolation = interpolation

    @staticmethod
    def get_params(img):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        for attempt in range(10):
            area = img.size[0] * img.size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback
        w = min(img.size[0], img.size[1])
        i = (img.size[1] - w) // 2
        j = (img.size[0] - w) // 2
        return i, j, w, w

    def __call__(self, sample):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly cropped and resize image.
        """
        i, j, h, w = self.get_params(sample['image'])
        return resized_crop(sample, i, j, h, w, self.size, self.interpolation)


# ...


class RandomHorizontalFlip(object):
    """Horizontally flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, sample):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return hflip(sample)
        return sample


class RandomVerticalFlip(object):
    """Vertically flip the given PIL Image randomly with a probability of 0.5."""

    def __call__(self, sample):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < 0.5:
            return vflip(sample)
        return sample


def hflip(sample):
    """Horizontally flip the given PIL Image.

    Args:
        sample (PIL Image): Image to be flipped.

    Returns:
        PIL Image:  Horizontall flipped image.
    """

    img, mask = sample['image'], sample['mask']

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    img = img.transpose(Image.FLIP_LEFT_RIGHT)
    mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

    return {'image': img, 'mask': mask}


def vflip(sample):
    """Vertically flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image:  Vertically flipped image.
    """

    img, mask = sample['image'], sample['mask']

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    img = img.transpose(Image.FLIP_TOP_BOTTOM)
    mask = mask.transpose(Image.FLIP_TOP_BOTTOM)

    return {'image': img, 'mask': mask}


# ...


class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, sample):
        """
        Args:
            sample (dict): {'image': PIL Image or numpy.ndarray, 'mask': PIL Image or numpy.ndarray}.

        Returns:
            dict: Converted sample.
        """
        return to_tensor(sample)


def to_tensor(sample):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    See ``ToTensor`` for more details.
    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
    Returns:
        Tensor: Converted image.
    """

    pic, mask = sample['image'], sample['mask']
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            img = img.float()

        return {'image': img, 'mask': mask}

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        pic = torch.from_numpy(nppic)
        return {'image': pic, 'mask': mask}

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        img = img.float() / 255.0
        # img = img.float()

    # handle PIL Image
    if mask.mode == 'I':
        img2 = torch.from_numpy(np.array(mask, np.int32, copy=False))
    elif mask.mode == 'I;16':
        img2 = torch.from_numpy(np.array(mask, np.int16, copy=False))
    elif mask.mode == 'F':
        img2 = torch.from_numpy(np.array(mask, np.float32, copy=False))
    else:
        img2 = torch.ByteTensor(torch.ByteStorage.from_buffer(mask.tobytes()))
    # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if mask.mode == 'YCbCr':
        nchannel = 3
    elif mask.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(mask.mode)
    img2 = img2.view(mask.size[1], mask.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img2 = img2.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img2, torch.ByteTensor):
        img2 = img2.float()

    return {'image': img, 'mask': img2}

The code looks beautiful. At least while skimming through it I couldn’t find any issues.

Since the error appears on different images, we should find the reason for it.
Could you remove the transformations from all images and try it again? Also, try to remove all multiprocessing (num_workers=0).
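For example (sketch; it keeps only your ToTensor so the batches can still be collated):

transform_debug = transforms.Compose([transforms.ToTensor()])
train_dataset = MyDataset(root_dir_img, root_dir_gt, transform_debug)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # single-process loading: exceptions surface with a clean traceback
    pin_memory=False
)

If the error disappears with num_workers=0, it probably comes from the worker processes fighting over the lazily opened image files.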