Dataloader member variables not changing

I’m trying to add random scaling augmentation to my training loop. I did this by adding a member function to the dataset that selects a random scaling factor on each iteration, so that all the images in a batch are resized to the same scale and the dimensions stay consistent within that batch. However, when using the debugger, I notice that the sizes aren’t actually changing. If I set a breakpoint in the resample function, the scale changes correctly, but when setting breakpoints in the transformation function, it is always at the original scale.

When calling self._training_loader.dataset.resample_scale(), am I just operating on a copy of the dataset, such that I am not actually modifying the one used in the training loop?

    def resample_scale(self, reset=False):
        if hasattr(self, 'scale_range') and not reset:
            scale = random.uniform(*self.scale_range)
            scale_func = lambda x: int(scale * x / 32.0) * 32
            self.output_shape = [scale_func(x) for x in self.base_size]
        else:
            self.output_shape = self.base_size

### snippet from synchronised transforms
        for key, data in epoch_data.items():
            if key in ["l_img", "r_img", "l_seq", "r_seq"]:
                data = data.resize(self.output_shape, Image.BILINEAR)
                epoch_data[key] = torchvision.transforms.functional.to_tensor(data)
            elif key == "seg":
                data = data.resize(self.output_shape, Image.NEAREST)
                epoch_data[key] = self._seg_transform(data)
            elif key in ["l_disp", "r_disp"]:
                data = data.resize(self.output_shape, Image.NEAREST)
                epoch_data[key] = self._depth_transform(data)

### snippet from training loop
        for batch_idx, data in enumerate(self._training_loader):
            self._training_loader.dataset.resample_scale()

If you are using multiple workers (via num_workers >= 1) you shouldn’t change attributes of the Dataset stored in the DataLoader, as this would most likely only manipulate a copy of the Dataset and won’t be reflected on the original Dataset.
In your code snippet you are calling self._training_loader.dataset.resample_scale() inside the self._training_loader loop, which is exactly this case.

If you need to manipulate the underlying dataset, you could try to apply the changes before starting the next epoch.
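For example, something along these lines would resample once per epoch, before the worker processes copy the Dataset (a minimal sketch; `training_loader` and `num_epochs` are placeholders for your own objects):

    for epoch in range(num_epochs):
        # resample on the main-process Dataset *before* iterating, so the
        # workers are started with the updated output_shape
        training_loader.dataset.resample_scale()
        for batch_idx, data in enumerate(training_loader):
            ...  # usual forward / backward step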

This is for per-iteration image scaling. Would a better idea be to define a custom collate function for the dataloader and do the scaling there (see the sketch below)? There are also a bunch of auxiliary things, such as multiplying the scaling factor into the ground-truth optical flow, which I would also have to integrate into it.
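Something like this rough sketch is what I have in mind (the dict keys, the "flow" handling and the (0.5, 1.0) range are placeholders rather than my actual code, and the segmentation/disparity keys are omitted since they would need nearest-neighbour resizing):

    import random
    import torch.nn.functional as F
    from torch.utils.data.dataloader import default_collate

    def rand_scale_collate(batch, scale_range=(0.5, 1.0)):
        """Collate a list of per-sample dicts, then resize the whole batch
        to one randomly chosen scale (rounded to a multiple of 32)."""
        out = default_collate(batch)
        scale = random.uniform(*scale_range)
        h, w = out["l_img"].shape[-2:]
        new_hw = (int(scale * h / 32) * 32, int(scale * w / 32) * 32)
        for key in ("l_img", "r_img", "l_seq", "r_seq"):
            if key in out:
                out[key] = F.interpolate(out[key], size=new_hw,
                                         mode="bilinear", align_corners=False)
        if "flow" in out:
            # resize the flow field and rescale its displacement values
            out["flow"] = F.interpolate(out["flow"], size=new_hw, mode="nearest")
            out["flow"][:, 0] *= new_hw[1] / w   # x-displacements
            out["flow"][:, 1] *= new_hw[0] / h   # y-displacements
        return out

It would then be passed to the DataLoader as collate_fn=rand_scale_collate.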

The main problem I am having is that I always get divergence between training and validation accuracy for segmentation, and one of the augmentations I am missing from the paper I’m branching from is random scaling. I’m very memory limited with only a 1070 Ti, so I operate at a significantly lower resolution and batch size than the literature, which used a bunch of V100s… could this also be a significant contributing factor?

Example output from the Cityscapes validation set below:

I’m not completely sure how your code works.
Are you defining a new rescaling size for a complete epoch?
If so, you could set the new size before the epoch is started, and thus before the dataset is copied to the worker processes in the DataLoader loop.

Yes, the resolution as well as a different batch size might yield different results.
That being said, you might be able to change some hyperparameters to get similar results, but you would have to further experiment with the training.

Also, you could try to use torch.utils.checkpoint to trade compute for memory in order to increase the batch size or the spatial resolution.
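A rough, self-contained illustration of the idea (a toy model, not yours): the checkpointed block recomputes its forward pass during backward instead of storing the intermediate activations:

    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    class TinySegNet(nn.Module):
        def __init__(self, num_classes=19):
            super().__init__()
            self.stem = nn.Conv2d(3, 64, 3, padding=1)
            self.block = nn.Sequential(           # the memory-hungry part
                nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
                nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
            )
            self.head = nn.Conv2d(64, num_classes, 1)

        def forward(self, x):
            x = self.stem(x)
            # activations of self.block are recomputed in backward;
            # newer PyTorch versions also accept use_reentrant=False here
            x = checkpoint(self.block, x)
            return self.head(x)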

I would’ve thought it would be best to have a new scaled size for each batch when training, which is why I have the resample_scale within the training iteration loop (not the epoch loop). Do you think varying the scaling per epoch would also be reasonably effective for data augmentation?

I also use random rotation (usually +/- 5 deg) and brightness (+/- 20%). I might try adding saturation as well. On some older architectures I had, I added a Gaussian noise layer during training; I might add this back in between the encoder and the seg/flow/depth heads.
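For reference, a rough torchvision sketch of those augmentations (the jitter and rotation values are just the ones mentioned above, the noise sigma is arbitrary, and the geometric transforms of course have to be applied synchronously to the images and the label maps, as in _sync_transform):

    import torch
    import torch.nn as nn
    import torchvision.transforms as T

    # +/- 20% brightness (and optionally saturation) jitter, +/- 5 deg rotation
    photometric = T.ColorJitter(brightness=0.2, saturation=0.2)
    geometric = T.RandomRotation(degrees=5)

    class GaussianNoise(nn.Module):
        """Toy noise layer: adds zero-mean Gaussian noise only during training."""
        def __init__(self, sigma=0.01):
            super().__init__()
            self.sigma = sigma

        def forward(self, x):
            if self.training:
                return x + self.sigma * torch.randn_like(x)
            return x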

I’ll also have a look at torch.utils.checkpoint to increase batch size. I am already using AMP to take advantage of the smaller memory footprint with fp16 (even if I don’t get the other benefits since I’m on Pascal). I ordered a personal RTX 3090 which should hopefully come in a few weeks. Although this is all for a final year undergraduate project which is due in two months so getting some good training in by the time the 3090 arrives is really pushing it…

If you want to apply a new scaling for the complete batch, you could either resample the data inside the training loop or use an approach where the sampler passes the index as well as the new scaling to the Dataset.
I haven’t tested this approach, so let me know if you get stuck.

It is done, and it works! I implemented a batch sampler that yields a tuple of (image index, scale factor) for each sample.

import random

from torch.utils.data import Sampler
from torch.utils.data import SequentialSampler

class BatchSamplerRandScale(Sampler):
    r"""Extending the Batch Sampler to also pass a scale factor for
        random scale between a list of ranges.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
            with ``__len__`` implemented.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``
        scale_range (List): The [min, max] range from which the per-batch scale factor is uniformly sampled

    Example:
        >>> list(BatchSamplerRandScale(SequentialSampler(range(10)), batch_size=3, drop_last=False, scale_range=[0.5,1]))
        [[(0, 0.65), (1, 0.65), (2, 0.65)], [(3, 0.8), (4, 0.8), (5, 0.8)], [(6, 0.93), (7, 0.93), (8, 0.93)], [(9, 0.54)]]
    """

    def __init__(self, sampler, batch_size, drop_last, scale_range):
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        assert len(scale_range) == 2
        self.scale_range = scale_range

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                scale_factor = random.uniform(*self.scale_range)
                batch = [(x, scale_factor) for x in batch]
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            scale_factor = random.uniform(*self.scale_range)
            batch = [(x, scale_factor) for x in batch]
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size

if __name__ == "__main__":
    test = list(BatchSamplerRandScale(SequentialSampler(range(10)), batch_size=3, drop_last=False, scale_range=[0.5, 1]))
    print(test)

A short example of its use with a dataset is as follows:


dataloader = torch.utils.data.DataLoader(
    training_dataset,
    num_workers=n_workers,
    batch_sampler=BatchSamplerRandScale(
        sampler=RandomSampler(training_dataset),
        batch_size=dataset_config.batch_size,
        drop_last=dataset_config.drop_last,
        scale_range=dataset_config.augmentations.rand_scale
    )
)
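Note that because a batch_sampler is passed, the DataLoader’s batch_size, shuffle, sampler and drop_last arguments are left at their defaults, since they are mutually exclusive with batch_sampler.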

class CustomDataset(torch.utils.data.Dataset):
    def __getitem__(self, idx):
        # BatchSamplerRandScale passes (index, scale_factor) tuples
        if isinstance(idx, tuple):
            idx, self.scale_factor = idx

        iter_data = get_iter_data(idx)
        self._sync_transform(iter_data)
        return iter_data

    def _sync_transform(self, iter_data):
        scale_func = lambda x: int(self.scale_factor * x / 32.0) * 32
        self.output_shape = [scale_func(x) for x in self.base_size]
        for key, data in iter_data.items():
            data = data.resize(self.output_shape, Image.BILINEAR)
            iter_data[key] = torchvision.transforms.functional.to_tensor(data)
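(The _sync_transform snippet above only shows the image-like keys for brevity; the optical-flow and disparity ground truth discussed earlier would additionally need their displacement/disparity values multiplied by the same resize ratio.)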