Run a user-defined function on multiple GPUs

Hello,
How can I run a user-defined function on multi-GPU in parallel? As a toy example, given function f, I would like to run f(a) on cuda:0 and f(b) on cuda:1 in parallel:

import torch

def f(x):
    return torch.matmul(x, x)

a = torch.rand((5, 5))
b = torch.rand((5, 5))

f(a)
f(b)

Thanks!

You just need to allocate a and b on the GPUs you want, and PyTorch will do the trick.
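A minimal sketch of what that looks like: because CUDA kernel launches are asynchronous, queuing f(a) on cuda:0 returns control to Python immediately, so f(b) can be queued on cuda:1 right away and the two matmuls execute concurrently. The fallback to CPU below is just so the snippet also runs on machines with fewer than two GPUs.

```python
import torch

def f(x):
    return torch.matmul(x, x)

# Pick two devices; fall back to CPU if fewer than two GPUs are available.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")

# Allocating the inputs on different devices is all that is needed:
# each op runs on the device its input tensor lives on.
a = torch.rand((5, 5), device=dev0)
b = torch.rand((5, 5), device=dev1)

# Both launches are queued without blocking; on two GPUs they overlap.
out_a = f(a)
out_b = f(b)

# Block until both GPUs have finished before reading the results.
if torch.cuda.is_available():
    torch.cuda.synchronize()

print(out_a.device, out_b.device)
```

Note there is no explicit threading here: the parallelism comes from the CUDA runtime executing the queued kernels on each device while the Python host thread moves on.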