Hello everyone,
I’m currently working with a model that follows an encoder-decoder architecture, and I need the encoder to stay in evaluation mode while the decoder is the component being trained. Running the model on a single GPU is clear to me, but I have some doubts when I try to use DistributedDataParallel (DDP). Here’s the relevant portion of my code:
# Set up for DDP
args.gpu = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(args.gpu)  # set default device
dist_backend = "nccl"
# ... (additional setup) ...

# Initialize the model
model = create_model().to(args.device)  # args.device = 'cuda'
model_without_ddp = model
if distributed:
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])  # args.gpu = 0
    model_without_ddp = model.module

# Training loop
for epoch in range(start_epoch, args.epochs):
    train_one_epoch(model, model_without_ddp, optimizer, dataloader, args)
    evaluate(model, dataloader_test, args)
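One detail not shown above: the optimizer only covers the decoder, since that is the only part being trained. Roughly like this (the optimizer class and learning rate here are just placeholders, not my exact settings):

# The optimizer only sees decoder parameters, since the encoder is not trained.
# AdamW and args.lr are placeholders for whatever is actually used.
optimizer = torch.optim.AdamW(model_without_ddp.decoder.parameters(), lr=args.lr)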
You’ll notice I’m using both model and model_without_ddp. This is based on PyTorch’s reference scripts, which I’ve followed to keep the code consistent and minimize condition checks for DDP usage.
def train_one_epoch(model, model_without_ddp, optimizer, dataloader, args):
    model.train()
    model_without_ddp.encoder.eval()  # replaces: if distributed: model.module.encoder.eval() else: model.encoder.eval()
    train_loss = 0.0
    for x, _ in dataloader:
        x = x.to(args.device)
        loss = model(x)
        optimizer.zero_grad()
        loss.backward()
        # Replaces: if distributed: model.module.decoder.parameters() else: model.decoder.parameters()
        torch.nn.utils.clip_grad_norm_(model_without_ddp.decoder.parameters(), 1.0)
        optimizer.step()
        train_loss += loss.item()
    return train_loss
@torch.inference_mode()
def evaluate(model, dataloader_test, args):
    model.eval()  # the entire model is in eval mode
    # ... (rest of the evaluation code) ...
    return metrics
My first question is: is this the correct approach for accessing and modifying the internal modules of a model wrapped in DDP, especially when switching between train() and eval() modes?
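For comparison, this is roughly how the same loop reads on a single GPU, where the submodules are accessed directly (a condensed sketch of the code above, just without the DDP wrapper):

# Single-GPU version: no DDP wrapper, so submodules are reached directly on the model.
model.train()
model.encoder.eval()  # keep the encoder in eval mode while training the decoder
for x, _ in dataloader:
    x = x.to(args.device)
    loss = model(x)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.decoder.parameters(), 1.0)
    optimizer.step()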
Additionally, I have a second question. In the training loop I use model(x) for the forward pass, which implicitly runs both the encoder and the decoder, yet for gradient clipping I explicitly use model_without_ddp.decoder.parameters() (the component I actually train). Should I be using model.parameters() instead, in case DDP does something internally that I’m not aware of (even though the encoder parameters are unnecessary there), or is my approach fine as it is?
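For concreteness, the alternative I’m asking about would be clipping over everything the DDP wrapper exposes:

# Alternative: clip over all parameters exposed by the DDP wrapper,
# which also includes the encoder parameters I don't optimize.
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)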
Thanks