Calling custom methods in the multi-GPU case?

I have multiple modules/objects that have multiple methods other than forward(). Is there a way to do the equivalent of def data_parallel(module, inputs, device_ids=None, output_device=None..) on those methods? Is it only a matter of changing the output = module(*input, **kwargs) line in parallel_apply() into something like output = getattr(module, method_name)(*input, **kwargs)? Would backward() work as expected after that?
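To make the getattr idea concrete without patching parallel_apply(), here is a minimal sketch that moves the dispatch into a thin wrapper's forward(), so the stock nn.DataParallel can be used unchanged. MethodDispatcher, Critic, penalty() and the method keyword are hypothetical names made up for this illustration; none of them are part of PyTorch. The sketch assumes each method returns per-sample values, so the gathered outputs concatenate cleanly along the batch dimension, and it falls back to a plain call on a machine without GPUs.

```python
import torch
import torch.nn as nn


class MethodDispatcher(nn.Module):
    """Hypothetical wrapper: forward() calls a named method of the wrapped module."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *inputs, method="forward", **kwargs):
        # Same idea as getattr(module, method_name)(*input, **kwargs),
        # but done inside forward() so DataParallel itself is untouched.
        return getattr(self.module, method)(*inputs, **kwargs)


class Critic(nn.Module):
    """Toy module (an assumption for the example) with an extra method."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)

    def forward(self, x):
        return self.net(x)  # shape (batch, 1)

    def penalty(self, x):
        # Per-sample output, so results from replicas concatenate along dim 0.
        return self.net(x).pow(2)  # shape (batch, 1)


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

critic = Critic()
parallel = nn.DataParallel(MethodDispatcher(critic)).to(device)

x = torch.randn(8, 4, device=device)

# Two different methods feed one combined loss, followed by a single backward pass.
loss = parallel(x).mean() + parallel(x, method="penalty").mean()
loss.backward()
```

The string passed as method is not a tensor, so DataParallel hands the same value to every replica. Gradients from the replicas accumulate into the parameters of the original critic, so one backward() on the combined loss is enough in this sketch.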

More specifically, I have an object that inherits from nn.Module and has some Parameters, some sub-modules, etc. I have multiple objects like this, and I call different methods on them to compute different parts of my loss, then combine those parts and do a backward pass. I compute different losses at different steps of execution.

Or should I just wrap everything into a giant super-Model that handles all other models, calls their methods as its children and data-parallelize it? And, like, returns different things based on parameters passed to forward()? Sounds like a messy solution.
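For comparison, the super-model variant might look roughly like the sketch below. All names here (Encoder, Decoder, LossBundle, the terms argument, and the two loss terms) are hypothetical and only stand in for the real models and objectives. It assumes each term is returned per sample in a dict, and that the scheduler's weighting is applied outside the parallelized module.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    def forward(self, x):
        return self.net(x)

    def sparsity(self, x):
        # Extra method besides forward(): per-sample L1 norm of the code.
        return self.net(x).abs().sum(dim=1)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 4)

    def forward(self, z):
        return self.net(z)

    def reconstruction_error(self, z, target):
        # Per-sample squared error.
        return (self.net(z) - target).pow(2).sum(dim=1)


class LossBundle(nn.Module):
    """Hypothetical super-model: owns the children and returns the requested terms."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, terms=("recon", "sparsity")):
        out = {}
        if "recon" in terms:
            out["recon"] = self.decoder.reconstruction_error(self.encoder(x), x)
        if "sparsity" in terms:
            out["sparsity"] = self.encoder.sparsity(x)
        return out  # dict of per-sample tensors, each of shape (batch,)


torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bundle = nn.DataParallel(LossBundle(Encoder(), Decoder())).to(device)

x = torch.randn(8, 4, device=device)
weights = {"recon": 1.0, "sparsity": 0.1}  # placeholder for scheduler output

per_sample = bundle(x, terms=tuple(weights))
loss = sum(weights[name] * values.mean() for name, values in per_sample.items())
loss.backward()
```

Keeping the weights outside forward() means the scheduler can change them at every step without touching the wrapped module; only the set of requested terms crosses into the data-parallel call.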

I actually do have a complicated problem like that, with many objectives that are updated dynamically according to some scheduler, so it is probably going to be a mess either way :frowning: