Question about torch.max

I have a question about the following code:

import torch

tensor_1 = torch.randint(0, 10, [5, 2], dtype=torch.float32)
tensor_2 = torch.randint(0, 10, [4, 2], dtype=torch.float32)
max_1 = torch.max(tensor_1[:, None, :1], tensor_2[:, :1])

print(max_1.size())

The output is: torch.Size([5, 4, 1])

How is this computed?

Hi Chichi!

The short answer is PyTorch’s broadcasting.

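When called with two tensors, max() takes their element-wise maximum (like torch.maximum()), so the two arguments are first broadcast to a common shape. They have the following shapes: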

tensor_1[:, None, :1].shape = torch.Size([5, 1, 1])
tensor_2[:, :1].shape = torch.Size([4, 1])
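To see where those shapes come from, here is a small sketch that just prints them. The fixed seed is an addition for reproducibility and is not part of the original code; the point is that indexing with None inserts a new size-1 axis, while the slice :1 keeps an axis but with length 1.

```python
import torch

torch.manual_seed(0)  # added only so the example is reproducible
tensor_1 = torch.randint(0, 10, [5, 2], dtype=torch.float32)
tensor_2 = torch.randint(0, 10, [4, 2], dtype=torch.float32)

# ':' keeps dim 0 (size 5), 'None' inserts a new axis of size 1,
# ':1' keeps only the first column, so the last axis has size 1
a = tensor_1[:, None, :1]
# ':' keeps dim 0 (size 4), ':1' keeps only the first column
b = tensor_2[:, :1]

print(a.shape)  # torch.Size([5, 1, 1])
print(b.shape)  # torch.Size([4, 1])
```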

When broadcasting, the shapes are aligned from the right. The two trailing 1s line up, the middle 1 of the first argument gets broadcast to match the 4 of the second argument, and the “missing” leading dimension of the second argument is treated as 1 and broadcast to match the leading 5 of the first argument. Thus your final shape is [5, 4, 1].
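The rule described above can be checked with a short sketch. The helper broadcast_shape is hypothetical, written only to spell out the right-aligned rule by hand; torch.broadcast_shapes is PyTorch's own function for the same computation. The loop then confirms that element [i, j, 0] of the result is the larger of tensor_1[i, 0] and tensor_2[j, 0].

```python
import torch

def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: apply the broadcasting rule by hand.

    Shapes are aligned from the right; a missing leading dimension is
    treated as 1, and a size-1 dimension stretches to match the other.
    """
    ndim = max(len(shape_a), len(shape_b))
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)  # left-pad with 1s
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    out = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible sizes {da} and {db}")
        out.append(db if da == 1 else da)  # a size-1 dim takes the other size
    return tuple(out)

torch.manual_seed(0)  # added only for reproducibility
tensor_1 = torch.randint(0, 10, [5, 2], dtype=torch.float32)
tensor_2 = torch.randint(0, 10, [4, 2], dtype=torch.float32)
a = tensor_1[:, None, :1]  # shape [5, 1, 1]
b = tensor_2[:, :1]        # shape [4, 1]

max_1 = torch.max(a, b)

print(broadcast_shape(a.shape, b.shape))         # (5, 4, 1)
print(torch.broadcast_shapes(a.shape, b.shape))  # torch.Size([5, 4, 1])

# Every row i of tensor_1 is compared against every row j of tensor_2
for i in range(5):
    for j in range(4):
        assert max_1[i, j, 0] == max(tensor_1[i, 0], tensor_2[j, 0])
```

In other words, the result is a 5 x 4 table of pairwise maxima of the first columns, with the trailing size-1 axis carried along.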

Best.

K. Frank
