How does broadcasting work when the input of torch.max is two tensors of different shape?

For example, say a.shape=(2, 3) and b.shape=(2, 3), then what's the meaning of torch.max(a[:, None, :], b)? It's confusing to me that an element-wise maximum is applied to two tensors of different shapes.

Let's consider the following example:

import torch

a = torch.arange(6).reshape(2, 3)
# tensor([[0, 1, 2],
#         [3, 4, 5]])

b = torch.ones(2, 3) * 3
# tensor([[3., 3., 3.],
#         [3., 3., 3.]])

In your example, indexing with None adds a new dimension of size 1 to a:

print(a[:, None, :].shape)
# torch.Size([2, 1, 3])

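As you correctly mentioned, this triggers broadcasting. Shapes are aligned starting from the last dimension, so b with shape (2, 3) is treated as (1, 2, 3), and every dimension of size 1 is (virtually) repeated to match the other tensor. The result therefore has shape (2, 2, 3). For a, the new dim 1 is the one being repeated. Writing that expansion out explicitly with repeat gives: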

print(a[:, None, :].repeat(1, 2, 1))
# tensor([[[0, 1, 2],
#          [0, 1, 2]],
#
#         [[3, 4, 5],
#          [3, 4, 5]]])
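To check the broadcast shape without doing the repeat by hand, here is a small sketch using torch.broadcast_shapes and torch.broadcast_tensors (both assume a reasonably recent PyTorch version). The names out_shape, a_exp and b_exp are only for illustration.

```python
import torch

a = torch.arange(6).reshape(2, 3)
b = torch.ones(2, 3) * 3

# Shapes are aligned from the trailing dimension:
#   a[:, None, :] -> (2, 1, 3)
#   b             -> (2, 3), treated as (1, 2, 3)
out_shape = torch.broadcast_shapes(a[:, None, :].shape, b.shape)
print(out_shape)  # torch.Size([2, 2, 3])

# broadcast_tensors returns expanded views with the common shape
a_exp, b_exp = torch.broadcast_tensors(a[:, None, :], b)
print(a_exp.shape, b_exp.shape)

# a is repeated along dim 1, b is repeated along dim 0
print(torch.equal(a_exp, a[:, None, :].repeat(1, 2, 1)))  # True
print(torch.equal(b_exp, b[None, :, :].repeat(2, 1, 1)))  # True
```

Note that broadcasting itself does not copy any data. The repeat calls are only used here to make the expansion visible.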

If we now apply torch.max, the tensor b is compared element-wise with each of the “batches” of the expanded a, where each “batch” has the same shape as b. In this trivial example every element of the first “batch” is smaller than 3, so the maximum is 3 everywhere. Every element of the second “batch” is equal to or bigger than 3, so its own values are kept. The result is a float tensor because b is float.

c = torch.max(a[:, None, :], b)
# tensor([[[3., 3., 3.],
#          [3., 3., 3.]],
#
#         [[3., 4., 5.],
#          [3., 4., 5.]]])
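Another way to read this result is index by index: c[i, j, k] = max(a[i, k], b[j, k]). The sketch below checks that reading with explicit loops. The name expected is just an illustrative helper, and the loops are for understanding only, not something to use in practice.

```python
import torch

a = torch.arange(6).reshape(2, 3)
b = torch.ones(2, 3) * 3

c = torch.max(a[:, None, :], b)

# Explicit version of the broadcast: i indexes rows of a,
# j indexes rows of b, k indexes the shared last dimension.
expected = torch.empty_like(c)
for i in range(a.shape[0]):
    for j in range(b.shape[0]):
        for k in range(a.shape[1]):
            expected[i, j, k] = max(a[i, k].item(), b[j, k].item())

print(torch.equal(c, expected))  # True
```

So every row of a is compared against every row of b, which is why the output has one extra dimension compared to the inputs.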

Hope this is clear and helps.