So, batch.flip(2) works fine, but how can I flip only half of the samples in the batch?
You could index/slice the tensor, flip one part of it, and reconstruct it afterwards.
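A minimal sketch of that slice-and-concatenate idea, assuming an NCHW batch and a deterministic split (the first `N // 2` samples are flipped, the rest are untouched). The helper name `flip_first_half` and its `dim` parameter are made up for illustration; pass 2 or 3 depending on which spatial axis you want to flip.

```python
import torch

def flip_first_half(batch: torch.Tensor, dim: int = 3) -> torch.Tensor:
    # Hypothetical helper: split along the batch dimension (dim 0)
    n = batch.size(0) // 2
    first, rest = batch[:n], batch[n:]
    # Flip only the first part, then stitch the pieces back together.
    # torch.cat returns a new tensor, so the input batch is left unchanged.
    return torch.cat((first.flip(dim), rest), dim=0)
```

Because of the integer division, an odd batch size (e.g. 5) flips the smaller half (2 samples).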
Is this a good way?
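```python
# Pick a random half of the sample indices along the batch dimension
perm = torch.randperm(batch.size(0))
idx = perm[:batch.size(0)//2]
# Flip those samples along dim 3 and write them back into batch in place
batch[idx] = batch[idx].flip(3)
```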
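If you'd rather not modify `batch` in place (for example, when it is part of an autograd graph), a mask-based variant is one option. This is a sketch under that assumption, not a required change; the name `random_flip_half` is hypothetical. Note that it computes `batch.flip(dim)` for the whole batch and then picks per sample, so it does a bit more work than the indexing version.

```python
import torch

def random_flip_half(batch: torch.Tensor, dim: int = 3) -> torch.Tensor:
    n = batch.size(0)
    # Boolean mask that is True for a random n // 2 of the samples
    mask = torch.zeros(n, dtype=torch.bool, device=batch.device)
    mask[torch.randperm(n, device=batch.device)[: n // 2]] = True
    # Reshape to (n, 1, 1, ...) so the mask broadcasts over the non-batch dims
    mask = mask.view(n, *([1] * (batch.dim() - 1)))
    # Flipped values where mask is True, original values elsewhere (new tensor)
    return torch.where(mask, batch.flip(dim), batch)
```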
Yes, your code looks good and seems to work fine:
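```python
import torch

# Each sample's last dim reads [0, 0, 1, 1], so a flip is easy to spot
batch = torch.cat((torch.zeros(4, 2, 4, 2), torch.ones(4, 2, 4, 2)), dim=3)
print(batch)
perm = torch.randperm(batch.size(0))
idx = perm[:batch.size(0)//2]
batch[idx] = batch[idx].flip(3)
# Two of the four samples now read [1, 1, 0, 0] along the last dim
print(batch)
```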
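Instead of reading through the full printout, you could also compare against a copy taken before the flip. A small sketch, reusing the same toy batch; the seed is only there to make a run repeatable.

```python
import torch

torch.manual_seed(0)
batch = torch.cat((torch.zeros(4, 2, 4, 2), torch.ones(4, 2, 4, 2)), dim=3)
original = batch.clone()  # keep an untouched copy for comparison

perm = torch.randperm(batch.size(0))
idx = perm[:batch.size(0)//2]
batch[idx] = batch[idx].flip(3)

# A sample counts as flipped if it no longer matches its original copy
flipped = [i for i in range(batch.size(0)) if not torch.equal(batch[i], original[i])]
print(sorted(idx.tolist()), flipped)
```

The two printed lists should match, since every sample in this toy batch changes when flipped along the last dim.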