Is there a way to vectorize this code in pytorch?
def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b
With NumPy I can simply do:
import numpy as np

@np.vectorize
def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b
such that myfunc(a=[1, 2, 3, 4], b=2) returns array([3, 4, 1, 2]).
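For comparison, the NumPy documentation notes that np.vectorize is provided mainly for convenience and is essentially a Python for loop. The same branch can be expressed without a loop using np.where, which computes both a - b and a + b over the whole arrays and then picks one value per element based on the mask a > b. A minimal sketch, assuming a and b are anything np.asarray accepts and broadcast against each other:

```python
import numpy as np

def myfunc(a, b):
    """Elementwise: a - b where a > b, otherwise a + b."""
    a = np.asarray(a)
    b = np.asarray(b)
    # Both branches are evaluated on the full arrays; the boolean
    # mask (a > b) selects which result to keep at each position.
    return np.where(a > b, a - b, a + b)

print(myfunc(a=[1, 2, 3, 4], b=2))  # [3 4 1 2]
```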
Is there a way to do the same in pytorch?
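One possible approach, sketched under the assumption that the inputs are numeric and broadcastable, is to replace the if/else with torch.where(condition, x, y), which selects from x where condition is true and from y elsewhere. Like np.where, it evaluates both branches for every element, so this suits cheap expressions such as a - b and a + b. The helper name myfunc is kept from the question; the conversion with torch.as_tensor is an assumption so that plain Python lists and scalars are accepted.

```python
import torch

def myfunc(a, b):
    """Elementwise: a - b where a > b, otherwise a + b."""
    a = torch.as_tensor(a)
    # Put b on the same device as a; dtype promotion is left to PyTorch.
    b = torch.as_tensor(b, device=a.device)
    # The mask (a > b) chooses between the two precomputed branches.
    return torch.where(a > b, a - b, a + b)

result = myfunc(a=[1, 2, 3, 4], b=2)
print(result)  # tensor([3, 4, 1, 2])
```

Because torch.where is an ordinary tensor operation, the result stays on the input's device and gradients flow through whichever branch was selected at each element.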