I was wondering about a loss function defined as a simple Python function, for example:

```python
def loss(arg):
    return arg.abs().mean()
```
What is the right way to make it run on multiple GPUs?
My current approach would be to wrap it inside a Module, which can then be wrapped in a DataParallel:

```python
class ModuleWrapper(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg):
        return loss(arg)

my_parallel_loss = nn.DataParallel(ModuleWrapper())
```
I don’t find this piece of code very compact or readable. Would a DataParallel for functions like this one (which can be much longer) be a good idea? At first I thought a decorator could be a good solution, but that would require creating a new wrapper module for each call, which is probably not what we want.
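For what it's worth, here's a rough sketch of the decorator idea that avoids the per-call cost: the wrapper module is built once at decoration time, not on each call. `FnWrapper` and `data_parallel_fn` are hypothetical names I made up for illustration, not part of the PyTorch API:

```python
import torch
import torch.nn as nn

class FnWrapper(nn.Module):
    """Wraps an arbitrary function as an nn.Module so DataParallel can scatter its inputs."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

def data_parallel_fn(fn):
    """Decorator: build the DataParallel wrapper once, then reuse it on every call."""
    return nn.DataParallel(FnWrapper(fn))

@data_parallel_fn
def loss(arg):
    return arg.abs().mean()

# loss is now a DataParallel module, callable like the original function;
# on a CPU-only machine DataParallel simply runs the wrapped module directly.
```

Note that the decorated `loss` is an `nn.Module` rather than a plain function, so anything that inspects it (e.g. `functools.wraps`-style introspection) would see the module, not the original function.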