My current approach is the following, where `x_in` is the image tensor of shape `[batch_size, channels, height, width]`, `fs` is the side length of the square patches I want to sample, and `n` is the number of patches I want to extract from the current batch.
```python
def sample_patches(self, x_in, fs, n):
    # Materialize every overlapping fs x fs patch, then sample n of them.
    all_patches = (
        x_in.unfold(2, fs, 1)
            .unfold(3, fs, 1)
            .transpose(1, 3)
            .contiguous()
            .view(-1, x_in.size(1) * fs * fs)
    )
    # print(all_patches.size())
    n_sampling_patches = min(all_patches.size(0), n)
    indices = torch.randperm(all_patches.size(0))[:n_sampling_patches]
    indices = indices.cuda()
    patches = all_patches[indices]
    return patches
```
This code works but is quite slow. I suspect the culprit is the `.contiguous()` call, which copies every overlapping patch into memory even though I only keep `n` of them.
Is there a more efficient way to achieve this?
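One way to avoid the large copy is to sample patch positions first and only then gather the `n` patches, instead of materializing all of them with `.contiguous()`. Below is a sketch of that idea; `sample_patches_fast` is a hypothetical name, and the flattening order is chosen to match the `transpose(1, 3)` + `view` layout of the original code (each patch flattened as `[channels, fs, fs]`).

```python
import torch

def sample_patches_fast(x_in, fs, n):
    """Sample n random fs x fs patches without materializing all patches.

    Hypothetical replacement sketch: draws random top-left corners, then
    copies only the selected patches via advanced indexing.
    """
    b, c, h, w = x_in.shape
    per_image = (h - fs + 1) * (w - fs + 1)   # patch positions per image
    n_positions = b * per_image
    n = min(n, n_positions)

    # Random flat position indices, decoded into (batch, row, col).
    idx = torch.randint(n_positions, (n,), device=x_in.device)
    bi = idx // per_image
    rem = idx % per_image
    ri = rem // (w - fs + 1)                  # top row of each patch
    ci = rem % (w - fs + 1)                   # left column of each patch

    # Index grids of shape [n, fs, fs] covering each patch's pixels.
    offs = torch.arange(fs, device=x_in.device)
    rows = ri[:, None, None] + offs[None, :, None]
    cols = ci[:, None, None] + offs[None, None, :]

    # Advanced indices separated by the channel slice put the broadcast
    # dims first, giving shape [n, fs, fs, c].
    patches = x_in[bi[:, None, None], :, rows, cols]

    # Reorder to [n, c, fs, fs] and flatten like the original code.
    return patches.permute(0, 3, 1, 2).reshape(n, -1)
```

This copies only `n * c * fs * fs` elements rather than one copy of every overlapping patch, so it should be noticeably cheaper when `n` is much smaller than the number of possible patch positions. Note it samples positions with replacement (via `randint`), unlike the `randperm`-based original.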