np.vectorize in PyTorch

Is there a way to vectorize this code in PyTorch?

def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b

With NumPy I can simply do

import numpy as np

@np.vectorize
def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b

such that myfunc(a=[1, 2, 3, 4], b=2) returns array([3, 4, 1, 2]).
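For context, the NumPy documentation describes np.vectorize as a convenience rather than a performance tool: it still calls the Python function once per element. The sketch below (variable names are illustrative) compares it with np.where, which runs whole-array operations and is the closer analogue of what PyTorch offers.

```python
import numpy as np

@np.vectorize
def myfunc(a, b):
    "Return a-b if a>b, otherwise return a+b"
    if a > b:
        return a - b
    else:
        return a + b

a = np.array([1, 2, 3, 4])
b = 2

looped = myfunc(a, b)                       # calls the Python function per element
vectorized = np.where(a > b, a - b, a + b)  # computes both branches, then selects element-wise
print(looped, vectorized)                   # [3 4 1 2] [3 4 1 2]
```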

Is there a way to do the same in PyTorch?

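torch.where() does this: given a boolean condition, it takes each element from the second argument where the condition is True and from the third argument where it is False.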

import torch

a = torch.tensor([1, 2, 3, 4])
b = torch.tensor(2)
torch.where(a>b, a-b, a+b)

The code above returns tensor([3, 4, 1, 2]).
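A minimal sketch that wraps torch.where in a function with the same call shape as the NumPy version. The name myfunc is reused from the question; converting the inputs with torch.as_tensor is an assumption of this sketch so that plain Python lists and scalars are accepted. As with np.where, both a - b and a + b are computed over the full tensors before the element-wise selection, and b is broadcast against a.

```python
import torch

def myfunc(a, b):
    """Return a-b where a>b, otherwise a+b (element-wise)."""
    # Accept lists, scalars, or tensors; broadcasting handles a scalar b.
    a = torch.as_tensor(a)
    b = torch.as_tensor(b)
    return torch.where(a > b, a - b, a + b)

print(myfunc([1, 2, 3, 4], 2))  # tensor([3, 4, 1, 2])

# b can also be a tensor with the same (or a broadcastable) shape as a
print(myfunc(torch.tensor([1, 5]), torch.tensor([3, 2])))  # tensor([4, 3])
```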