Remove DataParallel from submodules

Context:

I have a main model netG, which contains a number of submodules (e.g. nn.Conv2d, nn.Sequential, and some custom modules I defined myself).

Due to some issues in the code, some submodules of netG don’t support data parallelism.

So, instead of wrapping the whole netG like this

nn.DataParallel(netG)

I do

netG.conv = nn.DataParallel(netG.conv)      # an nn.Conv2d (submodule names are illustrative)
netG.blocks = nn.DataParallel(netG.blocks)  # an nn.Sequential

for the submodules that support nn.DataParallel, and leave the other modules as they are (no data parallel).

Problem:

This works fine, but when I save netG, I want to remove the DataParallel wrappers from the whole model. Clearly I can’t do netG.module.state_dict(), because netG itself is not wrapped and therefore has no module attribute.

Candidate solutions:

I am thinking of recursively removing the DataParallel wrappers from netG’s children, but I am not sure whether I can do that in-place, or whether I have to construct a new netG object from scratch.
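
Something like this minimal sketch is what I have in mind (untested; unwrap_data_parallel is a name I made up, and it assumes every wrapped child is reachable through named_children):

import torch.nn as nn

def unwrap_data_parallel(module: nn.Module) -> nn.Module:
    # Recursively replace every nn.DataParallel child with the module it wraps.
    for name, child in module.named_children():
        if isinstance(child, nn.DataParallel):
            setattr(module, name, child.module)  # swap the wrapper out in place
        else:
            unwrap_data_parallel(child)          # recurse into plain submodules
    return module

# then, before saving:
# torch.save(unwrap_data_parallel(netG).state_dict(), PATH)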

What is the solution?

Thanks!

Maybe it’s not good practice to separate your modules. But have you tried saving your separate modules this way?

torch.save({
            'conv_state_dict': netG.conv.module.state_dict(),
            'blocks_state_dict': netG.blocks.module.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            ...
            }, PATH)
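
For loading them back, the reverse would be roughly this (assuming the same checkpoint keys as above, and that the submodules are wrapped in DataParallel again before loading; drop the .module otherwise):

checkpoint = torch.load(PATH)
netG.conv.module.load_state_dict(checkpoint['conv_state_dict'])
netG.blocks.module.load_state_dict(checkpoint['blocks_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])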

Thank you! I haven’t tried that, but that makes perfect sense.

Also, if some sub-modules can’t be parallelized (as in my case), what would be a good practice? Should we not parallelize the whole module at all?

I think DataParallel would work well if you can keep the batch dimension of every sub-module’s input in the same position (dim 0 by default), so that DataParallel can scatter it. :thinking:
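
As a rough sketch of what I mean (the module here is made up): DataParallel splits the input tensor along dim 0 and runs each chunk on a separate GPU, so every sub-module has to treat dim 0 as the batch dimension:

import torch
import torch.nn as nn

class BatchFirstBlock(nn.Module):
    # safe for DataParallel: never moves or merges the batch dimension (dim 0)
    def forward(self, x):              # x: (N, C, H, W)
        return torch.relu(x)

net = nn.Sequential(BatchFirstBlock(), nn.Conv2d(3, 8, 3))
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net)         # scatters inputs along dim 0 by default
out = net(torch.randn(16, 3, 32, 32))  # each GPU sees a slice of the 16 samples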