Hello,
I am implementing a custom network that first computes a latent encoding and then runs several forward passes that reuse it. I am wondering whether this kind of network can work with nn.DataParallel. My assumption is that nn.DataParallel only parallelizes self.forward. Does it also handle custom methods such as preprocess?
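```python
import torch.nn as nn

class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(16, 16)  # placeholder encoder (assumed)
        self.fc = nn.Linear(16, 16)       # placeholder layer (assumed)

    def preprocess(self, a):
        # Compute and cache the latent encoding once
        self.encoding = self.encoder(a)

    def forward(self, b):
        return self.encoding + self.fc(b)

model = nn.DataParallel(MyNetwork())
# Pre-encoding (called on the wrapped module, outside DataParallel)
model.module.preprocess(a)
# Parallel data pass
output1 = model(b1)
output2 = model(b2)
```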
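To make my assumption concrete, here is a toy, pure-Python sketch of the call flow I believe nn.DataParallel follows: scatter the batch, replicate the module, run only forward() on each replica in parallel threads, and gather the results. ToyModule and ToyDataParallel are hypothetical stand-ins, not PyTorch APIs, and plain integers replace tensors. In this model, preprocess() runs once, serially, on the wrapped module, and forward() is the only method that is fanned out. One caveat the toy cannot show: as far as I understand, the real wrapper broadcasts parameters and buffers to each GPU but not a plain tensor attribute like self.encoding, so replicas on other devices may see it on the wrong device.

```python
import copy
from concurrent.futures import ThreadPoolExecutor

class ToyModule:
    """Hypothetical stand-in for MyNetwork: caches an encoding, then adds it in forward."""
    def __init__(self, scale):
        self.scale = scale
        self.encoding = None

    def preprocess(self, a):
        # Hypothetical encoder: just scale the input.
        self.encoding = self.scale * a

    def forward(self, b):
        return [self.encoding + x for x in b]

class ToyDataParallel:
    """Hypothetical mimic of DataParallel's flow: scatter, replicate, forward, gather."""
    def __init__(self, module, num_replicas=2):
        self.module = module
        self.num_replicas = num_replicas

    def __call__(self, batch):
        if not batch:
            return []
        # Scatter: split the batch into contiguous chunks (ceil division for chunk size).
        size = -(-len(batch) // self.num_replicas)
        chunks = [batch[i:i + size] for i in range(0, len(batch), size)]
        # Replicate: shallow copies share non-parameter attributes like `encoding`.
        replicas = [copy.copy(self.module) for _ in chunks]
        # Only forward() runs in parallel, one thread per replica.
        with ThreadPoolExecutor(max_workers=len(chunks)) as pool:
            outputs = list(pool.map(lambda rc: rc[0].forward(rc[1]), zip(replicas, chunks)))
        # Gather: concatenate the per-replica outputs in order.
        return [y for out in outputs for y in out]

model = ToyDataParallel(ToyModule(scale=10), num_replicas=2)
model.module.preprocess(1)      # runs once, serially, on the wrapped module
print(model([1, 2, 3, 4, 5]))   # prints [11, 12, 13, 14, 15]
```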