multi-GPU for user defined functions

I was wondering about a loss function defined as a simple Python function, for example:

def loss(arg):
    return arg.abs().mean()

What is the right way to make it run on multiple GPUs?
My current way would be to wrap it inside an nn.Module, which can then be wrapped inside a DataParallel:

class ModuleWrapper(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg):
        return loss(arg)

my_parallel_loss = nn.DataParallel(ModuleWrapper())

I don’t find this piece of code very compact or readable. Would a DataParallel for functions like this one (which can be much longer) be a good idea? At first I thought a decorator could work, but that would require creating a new wrapper module for each call, which is probably not what we want.
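For what it’s worth, one reusable sketch is a generic module that wraps any callable, so only one class is ever defined. `FunctionWrapper` is just an illustrative name I’m introducing here, not part of the torch API; note that `nn.DataParallel` simply runs the wrapped module directly when no GPUs are available, so the snippet is testable on CPU:

```python
import torch
import torch.nn as nn


def loss(arg):
    return arg.abs().mean()


class FunctionWrapper(nn.Module):
    """Wrap an arbitrary function so DataParallel can scatter its inputs."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


# One wrapper class covers any function; no per-function Module needed.
parallel_loss = nn.DataParallel(FunctionWrapper(loss))

x = torch.tensor([[-1.0, 2.0], [3.0, -4.0]])
out = parallel_loss(x)  # mean of |x| = (1 + 2 + 3 + 4) / 4 = 2.5
```

This keeps the call site (`parallel_loss(x)`) as readable as the original function call, at the cost of one small extra class.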

Thanks !