I have multiple modules/objects that have several methods other than forward(). Is there a way to apply data_parallel(module, inputs, device_ids=None, output_device=None..) to those methods? Is it only a matter of changing the line output = module(*input, **kwargs) in parallel_apply() (in parallel_apply.py) to something like output = getattr(module, method_name)(*input, **kwargs)? Would backward() work as expected after that?
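One possible way to avoid editing parallel_apply() at all is a thin wrapper module whose forward() simply calls the named method. nn.DataParallel then scatters, replicates and gathers as usual. This is a sketch under assumptions: MethodCaller, TwoLosses and their method names are hypothetical, and the code falls back to a plain call when no GPU is present (nn.DataParallel does that itself on CPU-only machines). Because DataParallel's replicate step is differentiable, backward() on the combined loss should accumulate gradients in the original model's Parameters.

```python
import torch
import torch.nn as nn


class MethodCaller(nn.Module):
    """Hypothetical thin wrapper: forward() calls a named method of the
    wrapped module, so nn.DataParallel can scatter/replicate/gather for it."""

    def __init__(self, module: nn.Module, method_name: str):
        super().__init__()
        self.module = module  # registered as a submodule, so it gets replicated
        self.method_name = method_name

    def forward(self, *inputs, **kwargs):
        # On each replica, self.module is that device's copy of the wrapped module
        return getattr(self.module, self.method_name)(*inputs, **kwargs)


class TwoLosses(nn.Module):
    """Hypothetical stand-in for a module with extra loss-computing methods."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        return self.encoder(x)

    def reconstruction_loss(self, x):
        return (self.encoder(x) * self.scale).pow(2).mean()

    def sparsity_loss(self, x):
        return self.encoder(x).abs().mean()


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TwoLosses().to(device)

# Both wrappers hold the same model object, so they share its Parameters
recon = nn.DataParallel(MethodCaller(model, "reconstruction_loss"))
sparse = nn.DataParallel(MethodCaller(model, "sparsity_loss"))

x = torch.randn(16, 4, device=device)
# With several GPUs, 0-dim per-replica losses are gathered into a vector,
# so .mean() reduces them back to a scalar (a no-op on a single device)
loss = recon(x).mean() + 0.1 * sparse(x).mean()
loss.backward()  # gradients accumulate in model's original Parameters
```

Note that every call through a DataParallel wrapper replicates the model again, so calling several wrapped methods per step repeats that cost.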
More specifically, I have an object that inherits from nn.Module and has some Parameters, some sub-modules, etc. I have multiple objects like this, and I call different methods on them to compute different parts of my loss, then combine those parts and do the backward pass. I compute different losses at different steps of execution.
Or should I just wrap everything into a giant super-model that holds all the other models as children, calls their methods, and is itself data-parallelized? It would then have to return different things depending on the arguments passed to forward(). That sounds like a messy solution.
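For comparison, the super-model approach does not necessarily need hand-written branching: a generic dispatcher can take the child name and the method name as forward() arguments, and DataParallel's scatter copies non-tensor arguments such as strings unchanged to every replica. This is only a sketch; Dispatcher, Encoder, Decoder and the method names are hypothetical. Returning per-sample losses (shape [batch]) instead of scalars is a design choice here, so that gathering across GPUs yields a batch-sized vector whose mean is the usual loss.

```python
import torch
import torch.nn as nn


class Dispatcher(nn.Module):
    """Hypothetical super-model: holds several models as children and routes
    forward() to models[model_name].method_name, so one DataParallel covers all."""

    def __init__(self, **models: nn.Module):
        super().__init__()
        self.models = nn.ModuleDict(models)  # registers every child's Parameters

    def forward(self, model_name: str, method_name: str, *inputs, **kwargs):
        # String arguments reach each replica unchanged; tensors are split on dim 0
        return getattr(self.models[model_name], method_name)(*inputs, **kwargs)


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 8)

    def forward(self, x):
        return self.lin(x)

    def l1_penalty(self, x):
        # Per-sample values, so DataParallel's gather concatenates them on dim 0
        return self.lin(x).abs().mean(dim=1)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 4)

    def forward(self, z):
        return self.lin(z)

    def reconstruction_error(self, z, target):
        # Per-sample squared error, shape [batch]
        return (self.lin(z) - target).pow(2).mean(dim=1)


device = "cuda" if torch.cuda.is_available() else "cpu"
enc, dec = Encoder(), Decoder()
# .to() moves enc and dec in place, so the references above stay valid
dp = nn.DataParallel(Dispatcher(enc=enc, dec=dec).to(device))

x = torch.randn(16, 4, device=device)
z = dp("enc", "forward", x)
sparsity = dp("enc", "l1_penalty", x)
recon = dp("dec", "reconstruction_error", z, x)
loss = recon.mean() + 0.1 * sparsity.mean()
loss.backward()
```

A single wrapper keeps device configuration in one place, but each call still replicates every child, including the ones that call does not use.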
My actual problem really is that complicated: I have many objectives that are updated dynamically according to a scheduler, and that part is probably already a mess.
Thanks!