Hello,
I was wondering about a loss function defined as a plain Python function, for example:

def loss(arg):
    return arg.abs().mean()
What is the right way to make it run on multiple GPUs?
My current way would be to wrap it inside a Module, which can then be wrapped inside a DataParallel:

class module_wrapper(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg):
        return loss(arg)

my_parallel_loss = nn.DataParallel(module_wrapper())
I don’t find this piece of code very compact or readable. Would a DataParallel for plain functions like this one (which can be much longer) be a good idea? At first I thought a decorator could work, but that would require creating a new wrapper module on each call, which is probably not what we want.
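For what it’s worth, the per-call overhead can be avoided by having the decorator build the wrapper module (and the DataParallel instance) once, at decoration time. Here is a rough sketch; the decorator name `data_parallel_fn` is made up for illustration and this is not an official API:

```python
import functools

import torch
import torch.nn as nn


def data_parallel_fn(fn):
    """Hypothetical decorator: wrap a plain function in an nn.Module
    once, at decoration time, and run it through nn.DataParallel."""

    class _Wrapper(nn.Module):
        def forward(self, arg):
            return fn(arg)

    # Created a single time here, not on every call.
    parallel = nn.DataParallel(_Wrapper())

    @functools.wraps(fn)
    def wrapped(arg):
        return parallel(arg)

    return wrapped


@data_parallel_fn
def loss(arg):
    return arg.abs().mean()


x = torch.tensor([-1.0, 2.0, -3.0])
result = loss(x)  # mean of |x|; tensor(2.) on a single device
```

One caveat: with several GPUs, DataParallel scatters the input along dim 0 and gathers one scalar per replica, so you would still need a final `.mean()` over the gathered outputs to get a single loss value.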
Thanks !