I have the following tensor
```python
x = torch.tensor([[[[ 4, 6], [12, 10]], [[20, 22], [28, 30]]],
                  [[[ 1, 6], [13, 15]], [[16, 18], [29, 26]]]])
```
It has shape (batch_size, channels, height, width), here (2, 2, 2, 2). Now I want to apply
torch.cartesian_prod() to each element of the batch. I can use
.flatten(start_dim=1) to get a one-dimensional tensor for each batch element, i.e. a tensor of shape
(batch_size, channels*height*width). However,
torch.cartesian_prod() is only defined for one-dimensional tensors. Is there a workaround to compute the cartesian product for each batch element?
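To make the shape issue concrete, here is a quick check on the tensor above. Note that `start_dim=1` is what keeps the batch dimension; `start_dim=0` (the default of `flatten`) merges everything into a single vector. The shapes in the comments follow from the (2, 2, 2, 2) example.

```python
import torch

x = torch.tensor([[[[ 4, 6], [12, 10]], [[20, 22], [28, 30]]],
                  [[[ 1, 6], [13, 15]], [[16, 18], [29, 26]]]])

# start_dim=1 keeps the batch dimension and flattens channels, height, width
x_flat = x.flatten(start_dim=1)
print(x.shape)       # torch.Size([2, 2, 2, 2])
print(x_flat.shape)  # torch.Size([2, 8])

# start_dim=0 (the default) would merge the batch dimension as well
print(x.flatten(start_dim=0).shape)  # torch.Size([16])

# cartesian_prod works on one 1-D row at a time: 8 * 8 = 64 pairs
print(torch.cartesian_prod(x_flat[0], x_flat[0]).shape)  # torch.Size([64, 2])
```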
Currently I use
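```python
x_flat = x.flatten(start_dim=1)  # (batch_size, channels*height*width)
batch_size = 2
indices = []
for batch in range(batch_size):
    # cartesian_prod only accepts 1-D tensors, so pass the flattened row
    indices.append(torch.cartesian_prod(x_flat[batch], x_flat[batch]))
```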
which is not really elegant.
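One loop-free idea, sketched under the assumption that every batch row has the same length (which holds after `flatten(start_dim=1)`): compute `torch.cartesian_prod` once on the column *indices* `0..n-1`, then gather the values for all rows at once with advanced indexing. `batched_cartesian_prod` is a hypothetical helper name, not a PyTorch function, and the sketch assumes both inputs live on the same device.

```python
import torch

def batched_cartesian_prod(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cartesian product of a[i] and b[i] for each row i.

    a: (batch, n), b: (batch, m) -> (batch, n * m, 2), where
    out[i] equals torch.cartesian_prod(a[i], b[i]).
    Assumes a and b are on the same device.
    """
    # The product of column indices is identical for every row, so build it once.
    idx = torch.cartesian_prod(
        torch.arange(a.shape[1], device=a.device),
        torch.arange(b.shape[1], device=a.device),
    )  # (n * m, 2)
    # Gather first/second coordinates for all rows via advanced indexing.
    return torch.stack([a[:, idx[:, 0]], b[:, idx[:, 1]]], dim=-1)

x = torch.tensor([[[[ 4, 6], [12, 10]], [[20, 22], [28, 30]]],
                  [[[ 1, 6], [13, 15]], [[16, 18], [29, 26]]]])
x_flat = x.flatten(start_dim=1)

out = batched_cartesian_prod(x_flat, x_flat)
print(out.shape)  # torch.Size([2, 64, 2])

# Compare with the per-sample loop
loop = torch.stack([torch.cartesian_prod(row, row) for row in x_flat])
print(torch.equal(out, loop))  # True
```

The output has the same size as the stacked loop result, (batch_size, n*n, 2), so memory use is unchanged; the gain is replacing the Python loop with a single indexing operation.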