I am doing 2 forward passes on a resnet and trying to compute the gradients using the outputs from the first forward pass. When using multiple GPUs this works when the model is wrapped through nn.DataParallel but not when wrapped through nn.DistributedDataParallel. Below you can find my code.
import torch import torchvision import torch.backends.cudnn as cudnn import torchvision.models as models import utils.distributed as dist def main(): # Get the current device as set for current distributed process. # Check `launch` function in `utils.distributed` module. device = torch.cuda.current_device() # create model model = models.resnet50().cuda(device) batch_size = 32 # define loss function (criterion) and optimizer criterion = torch.nn.CrossEntropyLoss().to(device) cudnn.benchmark = True # Wrap model in DDP if using more than one processes. if dist.get_world_size() > 1: dist.synchronize() model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[device], find_unused_parameters=True ) # Using DataParallel works fine # model = torch.nn.DataParallel( # model, device_ids=[device] # ) ip_1 = torch.rand(batch_size,3,224,224).cuda(device) op_1 = model(ip_1) target_1 = torch.zeros(batch_size, dtype=torch.long).cuda(device) ip_2 = torch.rand(batch_size,3,224,224).cuda(device) op_2 = model(ip_1) target_2 = torch.zeros(batch_size, dtype=torch.long).cuda(device) # loss for the first example loss = criterion(op_1,target_1) loss.backward() #----------> Fails here when DDP is used if __name__ == "__main__": dist.launch( main, num_machines=1, num_gpus_per_machine=2, machine_rank=0, dist_url='tcp://localhost:10001', dist_backend='nccl' )
The error I get is
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor ] is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
In case you need to look at
distributed.py, here it is. My torch version is
Thanks in advance for your help