I have a problem with torch.max

Hello, I have a problem with the torch.max operation.

When I use torch.max on a 1x1x1xn tensor, like

import torch

a = torch.randn(1, 1, 1, 5)
a.max(0)

It produces a tuple of two 1x1x1xn tensors (one for the maximum values and another for the indices).

However, if I try the same operation with n = 1, like

import torch

b = torch.randn(1, 1, 1, 1)
b.max(0)

torch.max produces a tuple of two 1x1x1 tensors.

On the other hand,

import torch

c = torch.randn(2, 1, 1, 1)
c.max(0)

it produces a tuple of two 1x1x1 tensors.

I think b and c work correctly, and

a.max(0)

should produce a tuple of two 1x1xn tensors. What’s wrong with it?
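The shape rule the poster expects can be written down without PyTorch at all: reducing over `dim` removes that axis, so the result depends only on which axis is reduced, not on whether n is 1. This is a minimal sketch; `reduced_shape` is a hypothetical helper, not part of PyTorch, and its `keepdim` flag is assumed to mirror the `keepdim` parameter of `torch.max`.

```python
def reduced_shape(shape, dim, keepdim=False):
    """Shape a max-style reduction over `dim` is expected to produce (hypothetical helper)."""
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for {ndim}-d shape")
    dim %= ndim  # normalize negative dims, e.g. -1 -> last axis
    if keepdim:
        # keepdim=True: the reduced axis stays, with size 1
        return shape[:dim] + (1,) + shape[dim + 1:]
    # keepdim=False: the reduced axis is dropped entirely
    return shape[:dim] + shape[dim + 1:]


for shape in [(1, 1, 1, 5), (1, 1, 1, 1), (2, 1, 1, 1)]:
    print(shape, "->", reduced_shape(shape, 0))
```

Under this rule, a gives (1, 1, 5) while b and c both give (1, 1, 1), which is the behavior the question argues for.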

I think this is a bug that was recently fixed in master (I’m typing from my phone, otherwise I’d get you the reference).


Thank you. I will try the newest version.
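After upgrading, one way to confirm the fix is to assert the shapes directly instead of eyeballing them. This is a minimal sketch assuming a PyTorch version in which `Tensor.max(dim)` drops the reduced dimension by default (`keepdim=False`); `check_max_shapes` is a hypothetical helper name.

```python
import torch


def check_max_shapes():
    """Compare Tensor.max(0) output shapes with the expected reduced shapes (hypothetical helper)."""
    cases = {
        (1, 1, 1, 5): (1, 1, 5),  # tensor a from the question
        (1, 1, 1, 1): (1, 1, 1),  # tensor b
        (2, 1, 1, 1): (1, 1, 1),  # tensor c
    }
    for in_shape, expected in cases.items():
        values, indices = torch.randn(*in_shape).max(0)
        # both elements of the returned pair share the reduced shape
        assert tuple(values.shape) == expected, (in_shape, tuple(values.shape))
        assert tuple(indices.shape) == expected, (in_shape, tuple(indices.shape))
        # keepdim=True keeps dim 0 as a size-1 axis instead of dropping it
        kept, _ = torch.randn(*in_shape).max(0, keepdim=True)
        assert tuple(kept.shape) == (1,) + in_shape[1:]
    return True


print(torch.__version__, check_max_shapes())
```

If any assertion fails, the message shows the input shape and the shape actually returned, which is the information a bug report would need.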