Dear guys,
I found that some functions in PyTorch, like torch.max() or torch.sum(), don't accept the keepdim
argument, even though it is listed in the documentation.
For example, if I run the following code:
import numpy as np
import torch
a = torch.Tensor(np.arange(6).reshape(2, 3))
print torch.max(a, keepdim=True)
I get the following error:
print torch.max(a, keepdim=True)
TypeError: torch.max received an invalid combination of arguments - got (torch.FloatTensor, keepdim=bool), but expected one of:
* (torch.FloatTensor source)
* (torch.FloatTensor source, torch.FloatTensor other)
didn't match because some of the keywords were incorrect: keepdim
* (torch.FloatTensor source, int dim)
Is there a solution for this problem?
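The error message lists only one overload that takes a dimension, (torch.FloatTensor source, int dim). That suggests keepdim is meant to be used together with dim rather than on its own. Below is a minimal sketch, assuming a PyTorch version recent enough to have torch.tensor and keepdim (0.4 or later). The helper max_keepdim_all is hypothetical and exists only for illustration: it reduces one dimension at a time so that every dimension is kept with size 1.

```python
import numpy as np
import torch

a = torch.tensor(np.arange(6).reshape(2, 3), dtype=torch.float32)

# keepdim only applies when reducing along a dimension, so pass dim explicitly.
# With dim given, torch.max returns a (values, indices) pair.
values, indices = torch.max(a, dim=1, keepdim=True)
print(values.shape)   # torch.Size([2, 1])
print(indices.shape)  # torch.Size([2, 1])

# torch.sum follows the same pattern: keepdim goes together with dim.
col_sums = torch.sum(a, dim=0, keepdim=True)  # shape (1, 3)

# Hypothetical helper: take the overall max while keeping every dimension
# (each reduced to size 1), by reducing one dimension at a time.
def max_keepdim_all(t):
    for d in range(t.dim()):
        t, _ = t.max(dim=d, keepdim=True)
    return t

m = max_keepdim_all(a)
print(m.shape)  # torch.Size([1, 1])
```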
Moreover, when I run
torch.max(a)
I get a float, but I expected a 2-tuple: the first element a Tensor with the max value, the second its index. I think the code should be checked.
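A sketch of how the two call forms seem to differ, assuming a recent PyTorch. In such versions torch.max(a) returns a 0-dim tensor rather than the Python float reported above, and torch.argmax is available. Only the form with dim returns a (values, indices) pair. To get the index of the overall maximum, one option is argmax on the flattened tensor, converted back to a (row, col) position.

```python
import torch

a = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Without dim: a single reduction over all elements, no index returned
overall = torch.max(a)

# With dim: a (values, indices) pair, here one entry per row
values, indices = torch.max(a, dim=1)

# Index of the overall max: argmax works on the flattened tensor,
# so convert the flat index back to (row, col) for a 2-D tensor
flat_idx = int(torch.argmax(a))
row, col = divmod(flat_idx, a.shape[1])
print(flat_idx, row, col)  # 5 1 2
```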