Multiple forward passes followed by one backward pass -- DistributedDataParallel fails while DataParallel works

I am doing two forward passes through a ResNet and then computing the gradients from the output of the first forward pass only. With multiple GPUs, this works when the model is wrapped in nn.DataParallel but fails when it is wrapped in nn.DistributedDataParallel. My code is below.
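For comparison, here is a minimal single-process sketch of the same pattern. It uses a small stand-in model, a hypothetical nn.Sequential of Linear and ReLU layers rather than the resnet50 from the post, and involves no BatchNorm, DataParallel or DDP. Each forward pass records its own autograd graph, so calling backward() on the first output after a second forward pass is allowed. It should give the same gradients as a run where the second pass never happened.

```python
import torch
import torch.nn as nn


def first_output_grads(extra_forward: bool) -> list:
    torch.manual_seed(0)  # identical weights and first input in both runs
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    criterion = nn.CrossEntropyLoss()

    ip_1 = torch.rand(32, 8)
    target_1 = torch.zeros(32, dtype=torch.long)
    op_1 = model(ip_1)  # first forward pass: records graph #1

    if extra_forward:
        ip_2 = torch.rand(32, 8)
        op_2 = model(ip_2)  # second forward pass: records graph #2, never used

    loss = criterion(op_1, target_1)
    loss.backward()  # backpropagates through graph #1 only
    return [p.grad.clone() for p in model.parameters()]


with_second = first_output_grads(extra_forward=True)
without_second = first_output_grads(extra_forward=False)
print(all(torch.allclose(a, b) for a, b in zip(with_second, without_second)))  # True
```

In this plain-module setting, the second forward pass does not affect the first graph. The post reports that the failure only appears once the model is wrapped in DistributedDataParallel.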

```python
import torch
import torchvision
import torch.backends.cudnn as cudnn
import torchvision.models as models

import utils.distributed as dist

def main():
    # Get the current device as set for the current distributed process.
    # Check the `launch` function in the `utils.distributed` module.
    device = torch.cuda.current_device()

    # create model
    model = models.resnet50().cuda(device)
    batch_size = 32

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    cudnn.benchmark = True

    # Wrap model in DDP if using more than one process.
    if dist.get_world_size() > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )
        # Using DataParallel works fine
        # model = torch.nn.DataParallel(
        #     model, device_ids=[device]
        # )

    ip_1 = torch.rand(batch_size, 3, 224, 224).cuda(device)
    op_1 = model(ip_1)
    target_1 = torch.zeros(batch_size, dtype=torch.long).cuda(device)

    ip_2 = torch.rand(batch_size, 3, 224, 224).cuda(device)
    op_2 = model(ip_2)
    target_2 = torch.zeros(batch_size, dtype=torch.long).cuda(device)

    # loss for the first example
    loss = criterion(op_1, target_1)
    loss.backward()  # ----------> Fails here when DDP is used

if __name__ == "__main__":
    # Body not included in the post; `main` is presumably started through
    # the `launch` helper in `utils.distributed` (not shown).
    ...
```

The error I get is:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 4; expected version 3 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
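The message comes from autograd's version counter. Every tensor carries a counter that in-place operations increment, and a tensor saved for backward remembers the counter value it had when it was saved. If the values differ at backward time, autograd raises this error. The following is a minimal CPU-only sketch that triggers the same class of error. It is unrelated to resnet50 or DDP, and the names w and b are illustrative.

```python
import torch

w = torch.ones(3, requires_grad=True)
b = torch.ones(3)   # saved by the multiplication to compute the grad w.r.t. w
y = (w * b).sum()   # forward: autograd saves b while it is at version 0
print(b._version)   # 0

b.add_(1)           # in-place update between forward and backward
print(b._version)   # 1

try:
    y.backward()
    error_message = ""
except RuntimeError as err:
    # The message reports the tensor's shape, its current version and the expected one.
    error_message = str(err)
    print("RuntimeError:", err)
```

In the post's error, a tensor of shape [2048] was expected at version 3 but found at version 4. That points to one extra in-place modification between the first forward pass and the backward pass. The post does not identify which tensor or operation caused it.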

In case you need to look at it, here it is. My torch version is 1.7.0a0+8deb4fe.

Thanks in advance for your help.

@ramprasaath This looks like a bug. Could you please file an issue at