In this paper, when feeding real and generated images into the Discriminator, the authors apply data augmentation to each image and compute the loss according to the following algorithm.
If I only want to perform data augmentation on the real images, I can do so during batch data creation in torch.utils.data.Dataset, but how can I perform data augmentation on a Tensor generated from the Generator?
The following methods involve sending data between the CPU and the GPU, which causes some overhead and reduces computational speed.
Obviously, as you mentioned, transferring every tensor between the CPU and the GPU just to use the available transforms, which only work on PIL images, is not a proper solution.
As far as I know, there is no built-in function for cropping tensors, but for the other operations there is a solution.
One solution is to copy the torchvision source code and change the PIL input to a tensor. For instance, here is the implementation of `RandomResizedCrop`.
For cropping, plain indexing works fine, and for resizing there is a built-in function (`torch.nn.functional.interpolate`). The only missing piece is the random generation of the crop box, which can be copied from the source code.
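To illustrate why this avoids the CPU round trip, here is a minimal sketch showing that indexing plus `F.interpolate` keeps a generated image inside the autograd graph, so gradients still reach the generator. The helper name `crop_and_resize` and the fixed crop box are made up for illustration, and `fake` is only a stand-in for a generator output.

```python
import torch
import torch.nn.functional as F


def crop_and_resize(img: torch.Tensor, i: int, j: int, h: int, w: int,
                    size=(64, 64)) -> torch.Tensor:
    """Crop a (C, H, W) tensor with plain indexing, then resize it.

    Both steps are ordinary autograd ops that run on img's device, so
    gradients flow back through the crop to whatever produced img.
    """
    patch = img[:, i:i + h, j:j + w]   # rows i..i+h-1, cols j..j+w-1
    patch = patch.unsqueeze(0)         # interpolate expects (N, C, H, W)
    out = F.interpolate(patch, size=size, mode="bicubic", align_corners=False)
    return out[0]


# stand-in for a generator output that is part of the autograd graph
fake = torch.rand(3, 100, 100, requires_grad=True)
aug = crop_and_resize(fake, i=10, j=20, h=50, w=40)
aug.sum().backward()
print(aug.shape, fake.grad.abs().sum() > 0)
```

Pixels outside the crop simply receive a zero gradient, which is the behavior you want when the augmentation is applied to generated images.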
Here is what I changed; it works for a single 3D tensor of shape (C, H, W):
```python
import math
import random

import torch
import torch.nn.functional as F

# an arbitrary 3D input of shape (C, H, W)
x = torch.ones((3, 100, 100)) * 255
x[:, 25:75, 25:75] = 0

scale = (0.08, 1.0)
ratio = (3. / 4., 4. / 3.)
# x is (C, H, W), so height is dim -2 and width is dim -1
height, width = x.shape[-2], x.shape[-1]
size = (64, 64)

# sample a random crop box (logic copied from RandomResizedCrop)
area = height * width
log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
z = None
for _ in range(10):
    target_area = random.uniform(*scale) * area
    aspect_ratio = math.exp(random.uniform(*log_ratio))

    w = int(round(math.sqrt(target_area * aspect_ratio)))
    h = int(round(math.sqrt(target_area / aspect_ratio)))

    if 0 < w <= width and 0 < h <= height:
        i = random.randint(0, height - h)
        j = random.randint(0, width - w)
        z = i, j, h, w
        break

if z is None:
    # Fallback to central crop
    in_ratio = float(width) / float(height)
    if in_ratio < min(ratio):
        w = width
        h = int(round(w / min(ratio)))
    elif in_ratio > max(ratio):
        h = height
        w = int(round(h * max(ratio)))
    else:  # whole image
        w = width
        h = height
    i = (height - h) // 2
    j = (width - w) // 2
    z = i, j, h, w

# crop with plain indexing (start + size), then resize;
# F.interpolate expects a batch dimension, hence unsqueeze(0) and [0]
resized = F.interpolate(x[:, i:i + h, j:j + w].unsqueeze(0), size=size,
                        mode='bicubic', align_corners=False)[0]
```
You can also wrap this code in a class, following the structure of the source code I referenced.
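A minimal sketch of such a wrapper, assuming batches of shape (N, C, H, W). The class name `TensorRandomResizedCrop` is hypothetical; it mirrors the structure of torchvision's `RandomResizedCrop` (a `get_params` step plus the transform itself) and draws an independent crop for every image. The discriminator `D` in the usage part is a toy stand-in, not the model from the paper.

```python
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F


class TensorRandomResizedCrop(nn.Module):
    """Hypothetical tensor-only RandomResizedCrop for (N, C, H, W) batches."""

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
        super().__init__()
        self.size = size
        self.scale = scale
        self.ratio = ratio

    def get_params(self, height, width):
        """Return a crop box (i, j, h, w), same rules as the code above."""
        area = height * width
        log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
        for _ in range(10):
            target_area = random.uniform(*self.scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w
        # fallback: central crop clamped to the allowed aspect-ratio range
        in_ratio = width / height
        if in_ratio < min(self.ratio):
            w, h = width, int(round(width / min(self.ratio)))
        elif in_ratio > max(self.ratio):
            h, w = height, int(round(height * max(self.ratio)))
        else:  # whole image
            w, h = width, height
        return (height - h) // 2, (width - w) // 2, h, w

    def forward(self, batch):
        height, width = batch.shape[-2], batch.shape[-1]
        crops = []
        for img in batch:  # one random crop per image, stays on batch.device
            i, j, h, w = self.get_params(height, width)
            patch = img[:, i:i + h, j:j + w].unsqueeze(0)
            crops.append(F.interpolate(patch, size=self.size, mode="bicubic",
                                       align_corners=False))
        return torch.cat(crops, dim=0)


# usage sketch: augment both real and generated batches before D
augment = TensorRandomResizedCrop((64, 64))
real = torch.rand(4, 3, 100, 100)
fake = torch.rand(4, 3, 100, 100, requires_grad=True)  # stand-in for G(z)
D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))  # toy discriminator
d_real, d_fake = D(augment(real)), D(augment(fake))
```

The per-image Python loop keeps the logic identical to the single-tensor code; it is simple rather than fully vectorized.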
Note that the crop should slice `x[:, i:i+h, j:j+w]`, i.e. start plus size. This is safe because h and w are only a proportion of the original sizes, so i+h and j+w never exceed the original height or width.
That said, the code I provided still needs a few other fixes; for instance, it relies on non-torch modules such as `random` and `math`, which should be replaced by torch equivalents to work fully on tensors.
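As a sketch of that replacement, here is the crop-box sampling rewritten with torch's RNG only: `Tensor.uniform_` stands in for `random.uniform` and `torch.randint` for `random.randint`, so sampling follows `torch.manual_seed` or an explicit `torch.Generator`. The function name `get_params_torch` is hypothetical. The random numbers are tiny CPU scalars; the image tensor itself never has to move.

```python
import torch


def get_params_torch(height, width, scale=(0.08, 1.0),
                     ratio=(3. / 4., 4. / 3.), generator=None):
    """Sample a RandomResizedCrop box (i, j, h, w) using torch's RNG only."""
    area = height * width
    log_ratio = torch.log(torch.tensor(ratio))
    for _ in range(10):
        # Tensor.uniform_ replaces random.uniform
        target_area = area * torch.empty(1).uniform_(
            scale[0], scale[1], generator=generator).item()
        aspect_ratio = torch.exp(torch.empty(1).uniform_(
            log_ratio[0].item(), log_ratio[1].item(),
            generator=generator)).item()

        w = int(round((target_area * aspect_ratio) ** 0.5))
        h = int(round((target_area / aspect_ratio) ** 0.5))

        if 0 < w <= width and 0 < h <= height:
            # torch.randint excludes its upper bound, unlike random.randint
            i = torch.randint(0, height - h + 1, (1,), generator=generator).item()
            j = torch.randint(0, width - w + 1, (1,), generator=generator).item()
            return i, j, h, w

    # fallback to a central crop, clamped to the allowed aspect-ratio range
    in_ratio = width / height
    if in_ratio < min(ratio):
        w, h = width, int(round(width / min(ratio)))
    elif in_ratio > max(ratio):
        h, w = height, int(round(height * max(ratio)))
    else:  # whole image
        w, h = width, height
    return (height - h) // 2, (width - w) // 2, h, w
```

Passing the same seeded `torch.Generator` twice reproduces the same crop, which can help when debugging augmentation inside a training loop.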