Context:
I have a main model netG, which includes a bunch of modules (e.g. nn.Conv2d, nn.Sequential and some other modules defined by myself).
Due to some issues in the code, some submodules of netG don't support data parallel.
So, instead of wrapping the whole netG like this
nn.DataParallel(netG)
I do
nn.DataParallel(nn.Conv2d)
nn.DataParallel(nn.Sequential)
…
for the modules that support nn.DataParallel, and leave the other modules as they are (no data parallel).
Problem:
This works fine, but when I save netG, I want to remove the DataParallel wrappers from the whole model. Clearly I can't do netG.module.state_dict(), because netG itself is not wrapped, so it doesn't have the attribute module.
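To make the problem concrete, here is a minimal sketch. TinyNet is a hypothetical stand-in for netG (the real model is larger), and the layer sizes are arbitrary. It shows that wrapping only a child puts a "module." segment in the middle of that child's state_dict keys, while the top-level model has no module attribute.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for netG: one wrapped child, one plain child."""

    def __init__(self):
        super().__init__()
        self.conv = nn.DataParallel(nn.Conv2d(3, 8, 3))  # supports data parallel
        self.head = nn.Linear(8, 2)                      # left unwrapped


netG = TinyNet()

# Keys of the wrapped child carry an extra "module." segment.
keys = list(netG.state_dict().keys())
print(keys)

# The top-level model is not wrapped, so netG.module does not exist.
print(hasattr(netG, "module"))
```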
Candidate solutions:
I am thinking of recursively removing the DataParallel wrapper from netG's children, but I am not sure whether I can do that in place or whether I have to construct a new netG object from scratch.
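A rough sketch of what the recursive, in-place version might look like. The helper name unwrap_data_parallel and the TinyNet model are made up for illustration, and it assumes every wrapper is a plain nn.DataParallel whose inner network is reachable through .module.

```python
import torch.nn as nn


def unwrap_data_parallel(module):
    """Strip every nested nn.DataParallel wrapper; modifies the model in place."""
    # Peel off any wrapper(s) at this level.
    while isinstance(module, nn.DataParallel):
        module = module.module
    # Re-attach each child after unwrapping it recursively.
    # list(...) avoids touching the child dict while iterating over it.
    for name, child in list(module.named_children()):
        setattr(module, name, unwrap_data_parallel(child))
    return module


class TinyNet(nn.Module):
    """Hypothetical stand-in for netG, with or without wrapped children."""

    def __init__(self, parallel):
        super().__init__()
        wrap = nn.DataParallel if parallel else (lambda m: m)
        self.conv = wrap(nn.Conv2d(3, 8, 3))
        self.body = nn.Sequential(wrap(nn.Conv2d(8, 8, 3)), nn.ReLU())
        self.head = nn.Linear(8, 2)


netG = unwrap_data_parallel(TinyNet(parallel=True))

# The cleaned state_dict should load into a model built without any wrappers.
plain = TinyNet(parallel=False)
plain.load_state_dict(netG.state_dict())
```

Because this changes netG itself, the children would have to be wrapped again to keep training in parallel afterwards. Running the helper on copy.deepcopy(netG) just before saving would leave the original untouched.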
What is the solution?
Thanks!