Variable needed for gradient computation is changed by an in-place operation

Hi, I am training my network with Accelerate, which is built on top of torch distributed data parallel.
Strangely, I run into the 'inplace operation' error only when I use distributed training.
I have thought it over and over, but I cannot find the reason.
Can someone help me with this problem?

Code

import torch
from torch import nn
import einops
from accelerate import Accelerator

# cosine similarity between matching rows of x1 and x2
def cosin_metric(x1, x2):
    return torch.sum(x1 * x2, dim=1) / (torch.norm(x1, dim=1) * torch.norm(x2, dim=1))

class ArcFaceModelToy(nn.Module):
    def __init__(self):
        super(ArcFaceModelToy, self).__init__()
        self.arcface_model = nn.Linear(3*224*224, 512)
        self.bn1 = nn.BatchNorm2d(512)
        
    def forward(self, x):
        x = einops.rearrange(x, 'b c h w -> b (c h w)')
        x = self.arcface_model(x)
        x = einops.rearrange(x, 'b (c h w) -> b c h w', h=1, w=1)
        x = self.bn1(x)   # BatchNorm2d over the 512-dim feature, reshaped to b x 512 x 1 x 1
        x = einops.rearrange(x, 'b c h w -> b (c h w)')
        return x
    
class ConvGeneratorToy(nn.Module):
    def __init__(self):
        super(ConvGeneratorToy, self).__init__()
        self.conv_encoder = nn.Linear(224,224)
        
    def forward(self, x):
        x = self.conv_encoder(x)
        return x



class Trainer(object):
    def __init__(self, 
                 split_batches=True):
        
        # basic info 
        # accelerator 
        self.accelerator = Accelerator(
            split_batches = split_batches
        )
        
        self.arcface_model = ArcFaceModelToy()
        self.generator_model = ConvGeneratorToy()
        self.arcface_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.arcface_model)
        self.generator_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.generator_model)
 
        # optimizer G 
        generator_params = list(self.generator_model.parameters())
        self.optimizer_G = torch.optim.Adam(generator_params, lr=1e-2, betas=(0, 0.999))

        self.arcface_model, self.generator_model = self.accelerator.prepare(self.arcface_model, self.generator_model)
        self.optimizer_G = self.accelerator.prepare(self.optimizer_G)
        
        # the arcface model is frozen; only the generator is trained
        self.arcface_model.eval()
        self.arcface_model.requires_grad_(False)
        
    def train(self):
        device = self.accelerator.device
        # train 
        self.generator_model.train()
        
        for i in range(100):
            tgt_image, src_image = torch.randn(6,3,224,224), torch.randn(6,3,224,224)
            tgt_image = tgt_image.to(device)
            src_image = src_image.to(device)
            
            result_image_first = self.generator_model(tgt_image)
            result_image_second = self.generator_model(src_image)
            
            first_result_face_latent_vector = self.arcface_model(result_image_first)
            second_result_face_latent_vector = self.arcface_model(result_image_second)
            
            src_face_latent_vector = torch.randn_like(first_result_face_latent_vector).to(device)
            tgt_face_latent_vector = torch.randn_like(second_result_face_latent_vector).to(device)
            
            loss_G_face_ID_first = (1-cosin_metric(first_result_face_latent_vector, src_face_latent_vector)).mean()
            loss_G_face_ID_second = (1-cosin_metric(second_result_face_latent_vector, tgt_face_latent_vector)).mean()
            loss_G = (loss_G_face_ID_first + loss_G_face_ID_second)##+ loss_G_hair_ID_first + loss_G_hair_ID_second)
            
            self.optimizer_G.zero_grad()
            self.accelerator.backward(loss_G)
            self.optimizer_G.step()

if __name__=="__main__":
    trainer = Trainer()
    trainer.train()

Error Message

One of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 3; expected version 2 instead. 
Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Hi @Seungdae_Han,

Given that the stack trace says there's a tensor of size [512] at version 3 rather than 2, I'd expect you have an in-place operation somewhere in your code.

Can you try running your code within a torch.autograd.set_detect_anomaly context manager? That’ll help with the debugging!
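
For reference, here's roughly what that looks like; this is just a generic sketch with a throwaway sigmoid example, not the models from your code:

import torch

# 1) What the version numbers mean: sigmoid saves its output for the backward
#    pass, so editing that output in place bumps its version counter and
#    backward fails with the same "modified by an inplace operation" error.
x = torch.randn(4, requires_grad=True)
y = torch.sigmoid(x)
y.add_(1.0)                     # in-place edit on a tensor saved for backward
try:
    y.sum().backward()
except RuntimeError as e:
    print(e)

# 2) Anomaly detection: run the forward AND the backward inside the context
#    manager, so the error additionally prints a traceback of the forward op
#    that produced the tensor that was later modified.
with torch.autograd.set_detect_anomaly(True):
    x = torch.randn(4, requires_grad=True)
    y = torch.sigmoid(x)
    y.add_(1.0)
    try:
        y.sum().backward()      # error now points back at the sigmoid() call
    except RuntimeError as e:
        print(e)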

Hi @AlphaBetaGamma96, thank you for your help!

I also tried torch.autograd.set_detect_anomaly(True) for debugging, and this is the output:

File "trainer_minimal_example.py", line 87, in <module>
    trainer.train()
  File "trainer_minimal_example.py", line 71, in train
    first_result_face_latent_vector = self.arcface_model(result_image_first)
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "trainer_minimal_example.py", line 21, in forward
    x = self.bn1(x)
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 741, in forward
    return F.batch_norm(
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "trainer_minimal_example.py", line 87, in <module>
    trainer.train()
  File "trainer_minimal_example.py", line 82, in train
    self.accelerator.backward(loss_G)
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/accelerate/accelerator.py", line 1683, in backward
    loss.backward(**kwargs)
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/aistudio/anaconda3/envs/HSD_hair_modification/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I think it is related to the batch norm, but I cannot fix the problem.

Actually, I added the minimal example code above, and as far as I can see there is no explicit in-place operation in it…
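
A debugging sketch that might narrow it down (assuming the prepared model ends up as a DistributedDataParallel wrapper, so the plain module is reachable via .module; adjust if it isn't wrapped): print the private _version counters of the BatchNorm parameters and buffers around the two forward calls, before the backward. Whichever tensor jumps a version between being used in the forward and the backward is the one being modified in place; under DDP, buffer broadcasting at the start of every forward is one common source of such in-place updates.

# debugging sketch, to be dropped into the train loop in place of the two
# arcface forward calls above
bn = self.arcface_model.module.bn1   # assumes a DDP wrapper around ArcFaceModelToy

def bn_versions(tag):
    print(tag,
          'weight:', bn.weight._version,
          'bias:', bn.bias._version,
          'running_mean:', bn.running_mean._version,
          'running_var:', bn.running_var._version)

bn_versions('before 1st forward')
first_result_face_latent_vector = self.arcface_model(result_image_first)
bn_versions('after 1st forward ')
second_result_face_latent_vector = self.arcface_model(result_image_second)
bn_versions('after 2nd forward ')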

Hi Seungdae!

Try the suggestions in the following post and see if they help with your debugging:

Best.

K. Frank