Data Augmentation on the fly

Hi everyone.

I have a training set of 2997 samples, where each sample has size 24x24x24x16. I would like to augment it 24-fold through rotation. Ideally each rotation would be by 90 degrees, so in order to get 23 different samples (the first one is the original) I would have to change the axis of rotation [(0,1), (1,0), (2,0), (0,2)] etc. Six permutations are required.

Everything should happen on the fly, so once a sample is retrieved with __getitem__, I should generate 24 samples.

For now, I don't think any actual augmentation is being done (I adjusted the original version of the code I am working on). I printed some tensors and indeed I don't see any rotation, only the same sample replicated.
Also, I was considering moving everything to tensors instead of working with NumPy arrays, so I need to modify the code anyway.

This is the class in which I perform the augmentation on the fly:

class AugmentedDataGenerator(Dataset):
    # It initializes the instance of the class
    def __init__(self, x, y, aug_count=24):
        # features and labels 
        self.x = x
        self.y = y
        self.aug_count = aug_count

    def __len__(self):
        # multiplication for the augmentation
        return len(self.x) * self.aug_count

    def __getitem__(self, idx):
        # Original sample index (dataset index divided by the augmentation count)
        index = idx // self.aug_count 
        # rotation index will be in the range from 0 to 23. 
        rotation_index = idx % self.aug_count

        # retrieving the sample
        sample_x = self.x[index]
        sample_y = self.y[index]
        aug_x = self._rotate_sample(sample_x, rotation_index)
        aug_y = sample_y

        return aug_x, aug_y

    # This method performs rotation augmentation on the input sample
    def _rotate_sample(self, sample, rotation_index):
        output = np.zeros_like(sample)
        # axes of rotation
        axes = [(1, 2), (0, 2), (0, 1)]
        for axis in axes:
            # rotates the sample
            rotated_sample = np.rot90(sample, k=rotation_index, axes=axis)
            # keep the maximum value across the rotations along the different axes
            output = np.maximum(output, rotated_sample)

        # it holds the sample that has been rotated along all specified axes
        return output

This is the beginning of the training:

 # Loops over each batch in train_loader
    for i, (input, target) in enumerate(train_loader): 
        # Measure data loading time
        data_time.update(time.time() - end)
        # Moves both the input data and the target labels to the CPU or GPU 
        print("batch", i, "and input shape", input.shape)
        input = input.reshape(-1, 16, 24, 24, 24).to(device, non_blocking=True)
        target = target.view(-1, 1).to(device, non_blocking=True)

It would be much easier to augment the data once and save it to disk, but I ran out of memory.

I tried to apply this transformation:

transform = v2.Compose([

by doing this in the __getitem__ method:

# Retrieving a sample from the dataset given its index
    def __getitem__(self, idx):
        if self.transform is not None:
            features = self.x_data[idx]
            label = self.y_data[idx]
            index = idx // self.aug_count 
            rotation_index = idx % self.aug_count
            features = torch.stack([torch.tensor(self.transform(self.x_data[index])) for _ in range(24)], dim=0)
            label = self.y_data[index]
        return features, label

But then each batch has its size multiplied by 24 and the dimensions don't make sense:
Batch 0 - Shape of the input torch.Size([8, 24, 24, 24, 24, 16])

To conclude, I think I'm losing the thread here. I would like to understand these concepts better:

  • how to apply augmentation on the fly properly (in which class / methods)
  • how to deal with the resulting augmented data and the dimensions of the batches

Thank you in advance.

Hi Emma!

Do I understand correctly that each sample is a three-dimensional image with
height = width = depth = 24 (and with 16 channels)?

This makes sense to me – a cube has 24 (non-reflecting) rotations (including
the trivial rotation that doesn’t change anything).

You can count them like this: A cube has six faces, so you have six choices of
which face to rotate to the top. Then you have a choice of four rotations you
can make around the bottom-to-top axis (0, 90, 180, and 270 degrees). Then
six times four gives you 24.
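That six-times-four counting translates directly into code. Here is one way to enumerate all 24 rotations with numpy's rot90() (torch.rot90() has the same semantics); the helper name rotations24 is just for illustration:

```python
import numpy as np

def rotations24(cube):
    """Yield the 24 proper rotations of a 3-d array: move each of the
    six faces to the top, then spin 0/90/180/270 degrees about axis 0."""
    def spins(c):
        # The four rotations about axis 0 (the plane spanned by axes 1 and 2)
        for k in range(4):
            yield np.rot90(c, k, axes=(1, 2))

    yield from spins(cube)                             # original top face up
    yield from spins(np.rot90(cube, 2, axes=(0, 2)))   # bottom face up
    yield from spins(np.rot90(cube, 1, axes=(0, 1)))   # the four side
    yield from spins(np.rot90(cube, -1, axes=(0, 1)))  # faces, brought to
    yield from spins(np.rot90(cube, 1, axes=(0, 2)))   # the top one at
    yield from spins(np.rot90(cube, -1, axes=(0, 2)))  # a time

# An asymmetrically-labelled cube has 24 distinct rotations
cube = np.arange(27).reshape(3, 3, 3)
print(len({r.tobytes() for r in rotations24(cube)}))  # 24
```

The same function works unchanged on a (24, 24, 24, 16) sample, since the rotations only touch the first three (equal-sized, spatial) dimensions.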

Note that (assuming that I understand your use case correctly) you can’t get
all 24 rotations of the cube by rotating around a specific (0, 1)-style axis. You
have three such axes: bottom-to-top, left-to-right, and front-to-back. Around
each axis you have three non-trivial rotations. (The 0-degree rotation is trivial.)
So you have only ten such “around-an-axis” rotations: the trivial rotation plus
nine (three times three) non-trivial rotations.
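You can verify that count directly: applying every k = 0, ..., 3 in each of the three rotation planes to an asymmetrically-labelled cube produces only ten distinct arrays:

```python
import numpy as np

cube = np.arange(27).reshape(3, 3, 3)  # no two entries equal
seen = set()
for axes in [(0, 1), (0, 2), (1, 2)]:
    for k in range(4):
        seen.add(np.rot90(cube, k, axes=axes).tobytes())
print(len(seen))  # 10: the trivial rotation plus 3 planes x 3 turns
```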

I don’t follow your logic here.

This makes sense and is certainly doable (along the general lines sketched in
the code you posted).

Yes, I would recommend writing all of this in pytorch. This will simplify things a
little bit by not using numpy where pytorch can do the same thing. Plus using
pytorch tensors will get you the benefits of a gpu, if you have one.

Yes, rotation_index will run from 0 to 23.

You don’t say what sample_y is. I assume that it is some sort of ground-truth
target. Depending on your use case, you may or may not have to rotate sample_y
along with sample_x.

Here, you are passing rotation_index – which runs up to 23 – to rot90(). This
will only give you four distinct results (for k = 0, 1, 2, 3). (I assume that rot90()
returns the same result for, say, k = 1 and k = 5.) You need a scheme to map your
rotation_index to the 24 distinct rotations of a cube.
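(That assumption is easy to check; numpy's rot90() takes k modulo 4:)

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
# rot90() has only four distinct outputs; k is effectively taken mod 4
print(np.array_equal(np.rot90(a, k=1), np.rot90(a, k=5)))  # True
```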

I don’t understand why you are taking the maximum() here. If this is really what
you want, could you explain your use case in greater detail?

Performing augmentation on the fly makes good sense here. Your general approach
looks sound (although some of the details look wrong to me).


K. Frank


Thank you Frank for your quick and extensive response!

I am going to clarify some aspects because I was hasty in writing my original post.

Yes. Each sample represents a protein-drug complex that has been voxelized into a 3D grid. Each cell of the grid has 16 channels, so in the end each sample has shape (24, 24, 24, 16).

I think what you explained here is practically what I was trying to say, very confusedly, in the next lines (the ones you didn't follow).
I discarded the code I had posted and was working on because, as you said, it was incorrect.
Here is how I rewrote it. It is basic and not optimized, but I needed a better understanding of how the cube is rotated.

def rotate_sample(sample):
    # Initialize output array to store rotated samples
    output = np.zeros((10,) + sample.shape) 
    # Initialize counter to keep track of the current index
    counter = 0
    # Define axes representing different planes of rotation
    axes = [(0, 1), (1, 2), (0, 2)]
    # Save the original sample
    output[counter] = sample
    counter += 1
    # For each rotation plane, apply the 90-, 180-, and 270-degree rotations
    for plane in axes:
        for angle in [90, 180, 270]:
            # Rotate the sample (k=1,2,3)
            rotated_sample = np.rot90(sample, k=angle//90, axes=plane)

            # Save the rotated sample
            output[counter] = rotated_sample
            counter += 1
    # Return the original sample plus its nine rotations
    return output

def aug_data_generator(sample_x, sample_y):
    aug_count = 10 # Original + 9 rotations
    # Arrays for storing the augmented data
    aug_data_x = np.zeros((sample_x.shape[0] * aug_count,) + sample_x.shape[1:])
    aug_data_y = np.zeros((sample_y.shape[0] * aug_count,))
    # Take the given dataset
    for i in range(sample_x.shape[0]):
        # Apply the rotation on each sample
        aug_x = rotate_sample(sample_x[i])
        # Repeat the label
        aug_y = np.repeat(sample_y[i], aug_count)
        # Assign augmented data to the output arrays
        start_index = i * aug_count
        end_index = start_index + aug_count
        aug_data_x[start_index:end_index] = aug_x
        aug_data_y[start_index:end_index] = aug_y

    return aug_data_x, aug_data_y

I decided to keep 10, as you indicated, because I don't understand how to get to 24 different rotations of the cube. In fact, my reference code was not augmenting the data: it was creating 23 copies of each sample.

I have to do everything on the fly because I run out of memory if I try to save the augmented data to disk all at once. There are ways to write a lot of data to disk efficiently, but I need to read up on them because I have never done it before.

Thank you for the suggestion. I will try to use tensors instead of arrays. I have a GPU, so I will also move the tensors to it.

Yes, it is the ground-truth target. It is not rotated, but it is repeated for each augmented sample.

As you can see, I modified these two parts because they weren't clear to me either. The maximum, in particular, was arbitrary.

Apart from these problems with the rotation, I was more concerned about whether to generate the data dynamically or not, and whether there are well-defined patterns I could follow for augmentation. This last concern comes from the fact that I am just beginning to learn about the world of deep learning and PyTorch, so I am trying to keep the code as clean and simple as possible, following the structures found in the PyTorch tutorials and examples.

Thank you a lot!


This is the final version that should work correctly:

class AugmentedDataGeneratorPytorch(Dataset):
    def __init__(self, x, y, aug_count=10):
        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        self.aug_count = aug_count

    def __len__(self):
        return len(self.x) * self.aug_count

    def __getitem__(self, idx):
        original_index = idx // self.aug_count
        augmentation_index = idx % self.aug_count
        sample_x = self.x[original_index]
        sample_y = self.y[original_index]
        augmented_sample = self._rotate_sample(sample_x, augmentation_index)
        return augmented_sample, sample_y

    def _rotate_sample(self, sample, augmentation_index):
        if augmentation_index == 0:
            return sample

        axes = [(1, 2), (0, 2), (0, 1)]
        counter = 1
        for plane in axes:
            for angle in [90, 180, 270]:
                if counter == augmentation_index:
                    rotated_sample = torch.rot90(sample, k=angle//90, dims=plane)
                    return rotated_sample
                counter += 1

        raise ValueError("Invalid augmentation index")
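As a sanity check of the indexing scheme, here is the same _rotate_sample logic as a standalone sketch with NumPy's np.rot90 (which has the same semantics as torch.rot90); the ten augmentations of a random sample should all be distinct and keep the sample's shape:

```python
import numpy as np

def rotate_sample(sample, augmentation_index):
    # Mirrors AugmentedDataGeneratorPytorch._rotate_sample with np.rot90
    if augmentation_index == 0:
        return sample
    counter = 1
    for plane in [(1, 2), (0, 2), (0, 1)]:
        for angle in [90, 180, 270]:
            if counter == augmentation_index:
                return np.rot90(sample, k=angle // 90, axes=plane)
            counter += 1
    raise ValueError("Invalid augmentation index")

# All ten augmentations of a random (24, 24, 24, 16) sample should be
# distinct and keep the original shape
sample = np.random.default_rng(0).random((24, 24, 24, 16))
augmented = [rotate_sample(sample, i) for i in range(10)]
assert all(a.shape == (24, 24, 24, 16) for a in augmented)
assert len({a.tobytes() for a in augmented}) == 10
```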

Hi Emma!

Your scheme where you use an indexed Dataset (one that implements
__getitem__(self, idx)), where idx encodes both the actual image index
and also the rotation_index, and then use rotation_index to generate
your rotated image on the fly is fine.

(You could also follow the pytorch transforms approach where you would add
a random-cube-rotation transform to your Dataset. Now each idx would map
to a specific real image – no rotation_index encoded in it – but each time you
fetch an image with the same idx, you would get a different randomly-rotated
version of the same real image. This would be an alternative approach – I tend
to like your scheme better.)

If you then access your AugmentedDataGenerator using a DataLoader with
shuffle = True, the random idx that the DataLoader uses to retrieve an image
will map to a specific (but random) rotation of the “real” image, thus giving you
your desired data augmentation. There’s nothing the matter with this scheme.

It’s perfectly legitimate to augment your data using only 10 of the 24 rotations
of the cube. Nothing will break or give you wrong results. However, there really
are 24 rotations of the cube.

Here’s another way to count them:

A cube has 12 edges. Pick one of those edges and think about rotating the cube
so that that chosen edge becomes the lower front edge. But there are two ways
a given edge can become the lower front edge: it can either run left-to-right or
right-to-left. So there are 12 times 2 = 24 rotations of the cube.

Again, it is not possible to get all 24 rotations of the cube by using a single
rot90(). Consider the three (face-to-face) axes of a cube. (You have labelled
them [(1, 2), (0, 2), (0, 1)].) Any rot90 rotation of the cube leaves one
of the three axes unchanged, while swapping the other two with one another.
This makes it easy to construct a specific example of a cube-rotation for which
it is easy to see that it is not a rot90()-rotation.

Consider a rotation that permutes all three axes:

[(1, 2), (0, 2), (0, 1)] → [(0, 1), (1, 2), (0, 2)]

Such a rotation leaves no axis unchanged, so it can’t be a rot90()-rotation.
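Here is a concrete check of this, assuming an asymmetrically-labelled cube: a cyclic permutation of all three axes is an even permutation, hence a proper rotation (by 120 degrees about a body diagonal of the cube), and it is not among the arrays reachable with a single rot90():

```python
import numpy as np

cube = np.arange(27).reshape(3, 3, 3)  # asymmetric labelling

# Every array reachable with a single rot90() (including the identity)
single_rot90s = {np.rot90(cube, k, axes=ax).tobytes()
                 for ax in [(0, 1), (0, 2), (1, 2)] for k in range(4)}

# A cyclic permutation of the three axes: an even permutation, i.e. a
# proper rotation, 120 degrees about a body diagonal of the cube
diag = np.transpose(cube, (1, 2, 0))

print(diag.tobytes() in single_rot90s)  # False
```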

It is not necessary to use all of the cube rotations to augment your data, but why
not? In for a penny, in for a pound, as they say.


K. Frank

Thank you again Frank.

I will try augmenting with 24 rotations, even if it slows down training.
So basically what I have to do is:

  • Fix one face as the top face
  • Rotate the cube around the vertical axis passing through the top face

So, for example:

  • (0, 1) represents an edge running from left to right.
  • (1, 0) represents the same edge but described from right to left.

Thank you!