Suppose I have a module that combines some submodules, like:
import torch.nn as nn

class Comb(nn.Module):
    def __init__(self, encoder, decoder):
        super(Comb, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
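For the rest of the question, assume Comb also defines the obvious forward that just chains the two parts (this forward is my assumption of the usual setup, not anything special in my code):

    def forward(self, x):
        # assumed: plain composition of the two submodules
        return self.decoder(self.encoder(x))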
For distributed training, I wrap the module with DDP:
from torch.nn.parallel import DistributedDataParallel as DDP

comb = DDP(comb)
raw_comb = comb.module  # the underlying, unwrapped Comb instance
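In case it matters, the surrounding setup follows what I understand to be the standard torchrun recipe (the backend, environment-variable handling, and device placement below are my assumptions about a typical launch, not an exact copy of my script):

    import os
    import torch
    import torch.distributed as dist

    # assumed: launched with torchrun, which sets LOCAL_RANK per process
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    comb = Comb(encoder, decoder).to(local_rank)
    comb = DDP(comb, device_ids=[local_rank])
    raw_comb = comb.module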
Now, in the training loop, I want to call the submodules of comb directly, like:
feature = raw_comb.encoder(input)
out = raw_comb.decoder(feature)
loss = loss_fn(out, label)
loss.backward()
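In other words, the DDP wrapper's own __call__/forward is never invoked. The alternative I am comparing against would go through the DDP object itself, relying on the assumed forward from the class sketch above (loss_fn, input, and label are placeholders as before):

    # going through the DDP wrapper instead of the unwrapped submodules
    out = comb(input)          # DDP.__call__ -> Comb.forward
    loss = loss_fn(out, label)
    loss.backward()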
Will calling the submodules through .module like this break the gradient synchronization of the DDP-wrapped module?