How to modify the order of inputs for save_image in torchvision

I want to concatenate 6 inputs into one image such that the first column shows a_real_test and b_real_test, the second column shows b_fake_test and a_fake_test, and the third column shows a_recon_test and b_recon_test. I used the code below and it works well for a batch size of 1, i.e. when each input has shape [1, 3, 256, 256]. However, when the batch size is bigger than 1, e.g. [4, 3, 256, 256], the images appear in the wrong order. How can I change the cat call to fix this?

pic = (torch.cat([a_real_test, b_fake_test, a_recon_test, b_real_test, a_fake_test, b_recon_test], dim=0).data + 1) / 2.0
torchvision.utils.save_image(pic, args.results_dir+'/sample.jpg', nrow=3)