Is it inefficient to apply nn.DataParallel to an nn.Module whose sub-modules are already wrapped in nn.DataParallel?

We sometimes reuse existing models to build loss modules, such as a perceptual loss or a GAN's adversarial loss. If those existing models are already wrapped in nn.DataParallel, is it inefficient to wrap the loss module that uses them in nn.DataParallel as well?

For example, model1 and model2 make up a cycle-loss module as follows. In that case, is `cycle_loss = nn.DataParallel(CycleLoss(model1, model2)).cuda()` inefficient?

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning sub-modules
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)

    def forward(self, input):
        return F.relu(self.conv(input))


class CycleLoss(nn.Module):
    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        self.l1 = nn.L1Loss()

    def forward(self, input1, input2):
        loss = self.l1(self.model2(self.model1(input1)), input1)
        loss += self.l1(self.model1(self.model2(input2)), input2)
        return loss[None]  # add dim 0 so per-GPU losses can be concatenated


# make models
model1 = nn.DataParallel(Model()).cuda()
model2 = nn.DataParallel(Model()).cuda()

# loss module (nests nn.DataParallel around model1 and model2)
cycle_loss = nn.DataParallel(CycleLoss(model1, model2)).cuda()

# optimizer
opt = torch.optim.Adam(list(model1.parameters()) + list(model2.parameters()))

# train (data_loader is assumed to yield (input1, input2) batches)
for input1, input2 in data_loader:
    loss = cycle_loss(input1, input2)
    loss = loss.mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
```

I am worried that nesting nn.DataParallel makes the code perform a needless scatter and gather at every wrapped sub-module.

Thank you.

Yes, this is not great. With N GPUs, the outer nn.DataParallel replicates its module N times, and each of those replicas contains inner nn.DataParallel modules that replicate their own modules N times again. I think you'll end up with N^2 replicas instead of just N. You can add some logging to the forward functions to confirm what actually happens.
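A minimal sketch of that logging, assuming a recent PyTorch version. It wraps only the outermost CycleLoss in nn.DataParallel, keeps model1 and model2 as plain modules, and records every Model.forward call. The `name` argument and the `calls` list are hypothetical additions for illustration, and the tensor shapes are arbitrary. Without a GPU, nn.DataParallel simply calls the wrapped module, so the script still runs on CPU. To compare, wrap model1 and model2 in nn.DataParallel as in the question and see how the logged devices and batch sizes change.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

calls = []  # hypothetical log: (module name, device, batch size) per forward call


class Model(nn.Module):
    def __init__(self, name):
        super().__init__()
        self.name = name  # hypothetical label, used only for logging
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)

    def forward(self, x):
        # Record which device this replica runs on and how many samples it got.
        calls.append((self.name, str(x.device), x.shape[0]))
        print(f"{self.name}: device={x.device}, batch={x.shape[0]}")
        return F.relu(self.conv(x))


class CycleLoss(nn.Module):
    def __init__(self, model1, model2):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        self.l1 = nn.L1Loss()

    def forward(self, input1, input2):
        loss = self.l1(self.model2(self.model1(input1)), input1)
        loss = loss + self.l1(self.model1(self.model2(input2)), input2)
        return loss[None]  # shape (1,) so per-GPU losses can be concatenated


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Plain sub-modules: no nn.DataParallel here.
model1 = Model("model1")
model2 = Model("model2")

# Wrap only the outermost module, so inputs are scattered and outputs gathered once.
cycle_loss = nn.DataParallel(CycleLoss(model1, model2)).to(device)

# Create the optimizer after moving the modules to the target device.
opt = torch.optim.Adam(list(model1.parameters()) + list(model2.parameters()))

input1 = torch.randn(8, 1, 16, 16)
input2 = torch.randn(8, 1, 16, 16)

loss = cycle_loss(input1, input2).mean()
opt.zero_grad()
loss.backward()
opt.step()
print(f"loss={loss.item():.4f}")
```

With a single wrapper, each replica should see only its own slice of the batch, so the batch sizes logged for model1 add up to 16 (8 samples from input1 plus 8 from model2(input2)) whatever the number of GPUs.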