Compute the Cartesian product for a batched tensor

I have the following tensor

import torch

x = torch.tensor([[[[ 4,  6],
                    [12, 10]],
                   [[20, 22],
                    [28, 30]]],
                  [[[ 1,  6],
                    [13, 15]],
                   [[16, 18],
                    [29, 26]]]])

with dimensions (batch_size, channels, height, width). Now I want to apply torch.cartesian_prod() to each element of the batch. I can use .flatten(start_dim=1) to get a tensor of shape (batch_size, channels*height*width), i.e. a one-dimensional tensor for each batch element. However, torch.cartesian_prod() is only defined for one-dimensional tensors. Is there a workaround to compute the Cartesian product for each batch element?
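
For illustration, flattening everything except the batch dimension gives a two-dimensional tensor, which torch.cartesian_prod() does not accept (shapes shown for the example tensor above):

x_flat = x.flatten(start_dim=1)
print(x_flat.shape)  # torch.Size([2, 8]) -> (batch_size, channels * height * width)
# torch.cartesian_prod(x_flat, x_flat)  # not allowed, the inputs must be one-dimensional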

Currently I use

batch_size = 2
indices = []
for batch in range(batch_size):
    flat = x[batch].flatten()  # shape (channels * height * width,)
    indices.append(torch.cartesian_prod(flat, flat))

which is not really elegant.
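
For reference, with the example tensor above each entry of indices should contain all pairs of the 8 flattened values of one batch element, so (if I am not mistaken) the result looks like this:

print(len(indices))      # 2, one tensor per batch element
print(indices[0].shape)  # torch.Size([64, 2]), i.e. 8 * 8 pairs, each of length 2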