Distributed: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 4; expected version 3 instead

I have converted a repo:

to use PyTorch's distributed framework, but I get this weird error:

-- Process 3 terminated with the following error:
Traceback (most recent call last):
  File "/home/rachel/miniconda3/envs/pt_models/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
    fn(i, *args)
  File "/home/rachel/rachelve/pt_models/mshifted_mp/main.py", line 172, in main_worker
    mean_eer_arr,auc_arr,loss_arr=train_model(model, train_loader, test_loader, train_loader_1, args)
  File "/home/rachel/rachelve/pt_models/mshifted_mp/main.py", line 61, in train_model
    running_loss = run_epoch(model, train_loader_1, optimizer, center, args)
  File "/home/rachel/rachelve/pt_models/mshifted_mp/main.py", line 90, in run_epoch
    loss.backward()
  File "/home/rachel/miniconda3/envs/pt_models/lib/python3.7/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/rachel/miniconda3/envs/pt_models/lib/python3.7/site-packages/torch/autograd/__init__.py", line 149, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The traceback does not tell me exactly which part of the code fails inside loss.backward(), so I am finding it difficult to debug. I suspect I need to .clone() the problematic output, but I am not sure where.
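To make sure I understand what the error is even checking, here is a tiny standalone snippet (unrelated to my model) that produces the same class of error: a tensor saved for the backward pass carries a version counter, and modifying it in place afterwards makes backward refuse to use it:

import torch

w = torch.randn(3, requires_grad=True)
y = w.exp()        # exp() saves its output for the backward pass
y.add_(1.0)        # in-place op bumps y's version counter (0 -> 1)
loss = y.sum()
loss.backward()    # RuntimeError: ... [torch.FloatTensor [3]] is at version 1; expected version 0 instead

So somewhere in my code a tensor that autograd saved during the forward pass is being written in place before backward() runs; the bare [2048] shape looks like a 2048-wide vector (e.g. a BatchNorm weight/bias/running stat in the backbone, or my center vector) rather than one of the [B, 2048] outputs.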

The relevant model and loss code is as follows:

def run_epoch(model, train_loader, optimizer, center, args):
    total_loss, total_num = 0.0, 0
    for batch_idx, ((img1, img2), _) in enumerate(train_loader):#, desc='Train...'):

        if args.gpu is not None:
            img1 = img1.cuda(args.gpu, non_blocking=True)
            img2 = img2.cuda(args.gpu, non_blocking=True)

        optimizer.zero_grad()

        out_1 = model(img1)
        out_2 = model(img2)
        out_1 = out_1 - center
        out_2 = out_2 - center

        center_loss = ((out_1 ** 2).sum(dim=1).mean() + (out_2 ** 2).sum(dim=1).mean())
        loss = contrastive_loss(out_1, out_2) + center_loss

        loss.backward()

        optimizer.step()

        total_num += img1.size(0)
        total_loss += loss.item() * img1.size(0)

    return total_loss / (total_num)
def contrastive_loss(out_1, out_2):
    out_1 = F.normalize(out_1, dim=-1)#.clone()
    out_2 = F.normalize(out_2, dim=-1)#.clone()
    bs = out_1.size(0)
    temp = 0.35 #0.25   #0.35 gives increasing auroc
    # [2*B, D]
    out = torch.cat([out_1, out_2], dim=0)
    # [2*B, 2*B]
    sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temp)
    mask = (torch.ones_like(sim_matrix) - torch.eye(2 * bs, device=sim_matrix.device)).bool()
    # [2B, 2B-1]
    sim_matrix = sim_matrix.masked_select(mask).view(2 * bs, -1)

    # compute loss
    pos_sim = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temp)
    # [2*B]
    pos_sim = torch.cat([pos_sim, pos_sim], dim=0)
    loss = (- torch.log(pos_sim / sim_matrix.sum(dim=-1))).mean()
    return loss
def train_model(model, train_loader, test_loader, train_loader_1, args):
    model.eval()
    auc_arr, loss_arr = [], []
    auc, feature_space = get_score(model, args, train_loader, test_loader)
    print('Epoch: {}, AUROC is: {}'.format(0, auc))
    auc_arr.append(auc)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.00005)
    center = torch.FloatTensor(feature_space).mean(dim=0)
    center = F.normalize(center, dim=-1)
    center = center.cuda(args.gpu, non_blocking=True)
    for epoch in range(args.epochs):
        running_loss = run_epoch(model, train_loader_1, optimizer, center, args)
        print('Epoch: {}, Loss: {}'.format(epoch + 1, running_loss))
        loss_arr.append(running_loss)
        auc, _ = get_score(model, args, train_loader, test_loader)
        print('Epoch: {}, AUROC is: {}'.format(epoch + 1, auc))
        auc_arr.append(auc)
    return auc_arr,loss_arr
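For what it's worth, contrastive_loss only expects two [B, D] batches of features; a standalone call with made-up sizes looks like this (assuming the function above and its torch.nn.functional import are in scope):

import torch

B, D = 8, 2048                            # hypothetical batch size; resnet152 features are 2048-dim
out_1 = torch.randn(B, D, requires_grad=True)
out_2 = torch.randn(B, D, requires_grad=True)
loss = contrastive_loss(out_1, out_2)     # two [B, D] batches -> scalar NT-Xent-style loss
loss.backward()                           # no model involved, so no in-place/version error here
print(loss.item())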

The distributed parts of the code were adapted from here:

The model code is:

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = models.resnet152(pretrained=True)
        self.backbone.fc = torch.nn.Identity()
        freeze_parameters(self.backbone, train_fc=False)
    def forward(self, x):
        z1 = self.backbone(x)
        z_n = F.normalize(z1, dim=-1)
        return z1

def freeze_parameters(model, train_fc=False):
    for p in model.conv1.parameters():
        p.requires_grad = False
    for p in model.bn1.parameters():
        p.requires_grad = False
    for p in model.layer1.parameters():
        p.requires_grad = False
    for p in model.layer2.parameters():
        p.requires_grad = False
    if not train_fc:
        for p in model.fc.parameters():
            p.requires_grad = False
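To see which tensors in the model actually have the bare [2048] shape from the error message, a quick listing sketch (assuming the Model class above, before any DistributedDataParallel wrapping):

import torch

model = Model()
for name, p in model.named_parameters():
    if tuple(p.shape) == (2048,):
        print("param :", name)            # e.g. layer4 BatchNorm weights/biases
for name, b in model.named_buffers():
    if tuple(b.shape) == (2048,):
        print("buffer:", name)            # e.g. layer4 BatchNorm running stats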

I think it fails here: z1 = self.backbone(x)

I tried using .clone() around it, etc., but it does not seem to help.
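Concretely, this is the kind of placement I mean by "using clone around it" (hypothetical variants; the version error stayed the same):

# Variant A: clone inside Model.forward (rest of the class unchanged)
def forward(self, x):
    z1 = self.backbone(x).clone()   # clone right after the line I suspect
    return z1

# Variant B: clone the model outputs in run_epoch instead
out_1 = model(img1).clone() - center
out_2 = model(img2).clone() - center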