Is there a PyTorch analogue of Lua’s torch.cmax(tensor, value)?
I checked the docs: torch.max seems to want another tensor as an argument, so it doesn’t cut it.
There are workarounds, but an easy, direct way of doing it would be more concise and probably more efficient.
Thanks!
elanmart
(Marcin Elantkowski)
One way to accomplish this would be with the torch.max() function you mentioned:
import torch as th

X = th.FloatTensor([
    [-1, 2],
    [2, -1]
])
# Expand the one-element tensor to X's shape, then take the element-wise max.
scalar = th.FloatTensor([1])
th.max(X, scalar.expand_as(X))
gives you
1 2
2 1
[torch.FloatTensor of size 2x2]
Please keep in mind that I’ve only been playing with PyTorch for a few days, so this may not be the optimal solution.
EDIT: I remember one of the devs mentioned they’re considering adding an autograd.Scalar; I guess it would make things nicer in this case.
Yep, that’s the workaround I’m currently using, but it’s a little too verbose IMO compared to cmax.
smth
torch.clamp(tensor, max=value) # cmin
torch.clamp(tensor, min=value) # cmax
Edited to reflect what vadim said below
Seems it’s the other way around: torch.clamp(tensor, min=value) is cmax and torch.clamp(tensor, max=value) is cmin.
It works but is a little confusing at first.
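For anyone who finds the naming confusing, here is a small sketch of that mapping, reusing the 2x2 tensor from earlier in the thread. The variable names are my own and not from any of the posts. The way to read it: min= sets a lower bound (so it behaves like an element-wise max with the scalar), and max= sets an upper bound.

```python
import torch

X = torch.tensor([[-1.0, 2.0],
                  [2.0, -1.0]])

# min=1 is a lower bound: elements below 1 are raised to 1 (the old cmax behaviour).
lower_bounded = torch.clamp(X, min=1.0)   # values: [[1, 2], [2, 1]]

# max=1 is an upper bound: elements above 1 are lowered to 1 (the old cmin behaviour).
upper_bounded = torch.clamp(X, max=1.0)   # values: [[-1, 1], [1, -1]]

print(lower_bounded)
print(upper_bounded)
```

The first result matches what the torch.max(X, scalar.expand_as(X)) workaround returns for the same input.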
fmassa
(Francisco Massa)
cmax/cmin were removed to simplify the API, after some comments we received.
https://github.com/pytorch/pytorch/pull/455
Maybe someone is looking for this. Is there a way to do this without that much code?
import numbers
import torch

def is_number(value):
    # Assumed helper (not defined in the original post): plain Python numbers only.
    return isinstance(value, numbers.Number)

def min_(value1, value2):
    # Number vs. tensor: clamp the tensor from above; tensor vs. tensor: element-wise min.
    if is_number(value1) and isinstance(value2, torch.Tensor):
        return torch.clamp(value2, max=value1)
    elif is_number(value2) and isinstance(value1, torch.Tensor):
        return torch.clamp(value1, max=value2)
    elif isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor):
        return torch.min(value1, value2)

def max_(value1, value2):
    # Number vs. tensor: clamp the tensor from below; tensor vs. tensor: element-wise max.
    if is_number(value1) and isinstance(value2, torch.Tensor):
        return torch.clamp(value2, min=value1)
    elif is_number(value2) and isinstance(value1, torch.Tensor):
        return torch.clamp(value1, min=value2)
    elif isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor):
        return torch.max(value1, value2)

min_(2, torch.tensor(3))
min_(torch.tensor(3), 2)
min_(torch.tensor(3), torch.tensor(2))
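One possible shorter version, as a sketch only: it assumes a PyTorch release that ships torch.minimum and torch.maximum (1.7 or later, as far as I know). The names elementwise_min and elementwise_max are hypothetical, chosen just for this example. The idea is to wrap both arguments with torch.as_tensor and let broadcasting handle the number-vs-tensor cases.

```python
import torch

def elementwise_min(a, b):
    # torch.as_tensor returns tensors unchanged and wraps Python numbers
    # as 0-dim tensors, which then broadcast against the other operand.
    return torch.minimum(torch.as_tensor(a), torch.as_tensor(b))

def elementwise_max(a, b):
    # Same idea with the element-wise maximum.
    return torch.maximum(torch.as_tensor(a), torch.as_tensor(b))

print(elementwise_min(2, torch.tensor(3)))
print(elementwise_min(torch.tensor(3), 2))
print(elementwise_min(torch.tensor(3), torch.tensor(2)))
print(elementwise_max(torch.tensor([1, 5]), 3))
```

One difference from the functions above: with two plain numbers this returns a 0-dim tensor instead of None. I have not checked how it behaves when the tensor lives on a GPU, so treat that case as untested.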