Concat images iteratively

Hi everyone, suppose I have two tensors, each holding N RGB images of size 16×32 (width × height), so:

Input_0 = N × 3 × 32 × 16
Input_1 = N × 3 × 32 × 16

Now I want to concatenate these images pairwise, side by side along the width (e.g. the first image of Input_0 with each of the N images of Input_1, then the second image of Input_0 with each of the N images of Input_1, and so on), to get an output like:

Output = N × N × 3 × 32 × 32

How can I do that?
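Here is a minimal sketch of one way to do this in PyTorch. It assumes the concatenation runs along the width (the last dimension), since 16 + 16 = 32 matches the output shape above. Both tensors are broadcast to N × N × 3 × 32 × 16 with `unsqueeze` + `expand`, which create views without copying, and then joined with a single `torch.cat`. The function name `pairwise_concat` is just an illustrative choice. The explicit double loop appears only as a reference to check the result.

```python
import torch


def pairwise_concat(input_0: torch.Tensor, input_1: torch.Tensor) -> torch.Tensor:
    """Return out with out[i, j] = cat(input_0[i], input_1[j]) along the width.

    input_0: N0 x C x H x W, input_1: N1 x C x H x W  ->  out: N0 x N1 x C x H x 2W
    """
    n0 = input_0.size(0)
    n1 = input_1.size(0)
    # a[i, j] = input_0[i]: insert a new dim 1 and expand it to n1 (a view, no copy)
    a = input_0.unsqueeze(1).expand(-1, n1, -1, -1, -1)
    # b[i, j] = input_1[j]: insert a new dim 0 and expand it to n0 (a view, no copy)
    b = input_1.unsqueeze(0).expand(n0, -1, -1, -1, -1)
    # Concatenate along the last dimension (width); this allocates the output
    return torch.cat((a, b), dim=-1)


N = 4
input_0 = torch.randn(N, 3, 32, 16)
input_1 = torch.randn(N, 3, 32, 16)
out = pairwise_concat(input_0, input_1)
print(out.shape)  # torch.Size([4, 4, 3, 32, 32])

# Reference: the explicit double loop produces the same tensor
ref = torch.stack([
    torch.stack([torch.cat((input_0[i], input_1[j]), dim=-1) for j in range(N)])
    for i in range(N)
])
assert torch.equal(out, ref)
```

The output holds N² images, so memory grows quadratically with N. Using `dim=-2` instead would stack the images vertically (64 × 16), which does not match the 32 × 32 asked for here.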