RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation when using distributed training

Hi. I’m using the following 1D ResNet. I can train it in non-distributed mode without any error, but when I switch to DistributedDataParallel I get the “one of the variables needed for gradient computation has been modified by an inplace operation” error, which usually points to an in-place operation somewhere in the model.


import torch
import torch.nn as nn
import torch.nn.functional as F


class MyConv1dPadSame(nn.Module):
    """
    extend nn.Conv1d to support SAME padding
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super(MyConv1dPadSame, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.conv = torch.nn.Conv1d(
            in_channels=self.in_channels, 
            out_channels=self.out_channels, 
            kernel_size=self.kernel_size, 
            stride=self.stride, 
            groups=self.groups)

    def forward(self, x):
        
        net = x
        
        # compute pad shape
        in_dim = net.shape[-1]
        out_dim = (in_dim + self.stride - 1) // self.stride
        p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
        pad_left = p // 2
        pad_right = p - pad_left
        net = F.pad(net, (pad_left, pad_right), "constant", 0)
        
        net = self.conv(net)

        return net
        
class MyMaxPool1dPadSame(nn.Module):
    """
    extend nn.MaxPool1d to support SAME padding
    """
    def __init__(self, kernel_size):
        super(MyMaxPool1dPadSame, self).__init__()
        self.kernel_size = kernel_size
        self.stride = 1
        self.max_pool = torch.nn.MaxPool1d(kernel_size=self.kernel_size)

    def forward(self, x):
        
        net = x
        
        # compute pad shape
        in_dim = net.shape[-1]
        out_dim = (in_dim + self.stride - 1) // self.stride
        p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
        pad_left = p // 2
        pad_right = p - pad_left
        net = F.pad(net, (pad_left, pad_right), "constant", 0)
        
        net = self.max_pool(net)
        
        return net
    
class BasicBlock(nn.Module):
    """
    ResNet Basic Block
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False):
        super(BasicBlock, self).__init__()
        
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.stride = stride
        self.groups = groups
        self.downsample = downsample
        if self.downsample:
            self.stride = stride
        else:
            self.stride = 1
        self.is_first_block = is_first_block
        self.use_bn = use_bn
        self.use_do = use_do

        # the first conv
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.do1 = nn.Dropout(p=0.5)
        self.conv1 = MyConv1dPadSame(
            in_channels=in_channels, 
            out_channels=out_channels, 
            kernel_size=kernel_size, 
            stride=self.stride,
            groups=self.groups)

        # the second conv
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.do2 = nn.Dropout(p=0.5)
        self.conv2 = MyConv1dPadSame(
            in_channels=out_channels, 
            out_channels=out_channels, 
            kernel_size=kernel_size, 
            stride=1,
            groups=self.groups)
                
        self.max_pool = MyMaxPool1dPadSame(kernel_size=self.stride)

    def forward(self, x):
        
        identity = x
        
        # the first conv
        out = x
        if not self.is_first_block:
            if self.use_bn:
                out = self.bn1(out)
            out = self.relu1(out)
            if self.use_do:
                out = self.do1(out)
        out = self.conv1(out)
        
        # the second conv
        if self.use_bn:
            out = self.bn2(out)
        out = self.relu2(out)
        if self.use_do:
            out = self.do2(out)
        out = self.conv2(out)
        
        # if downsample, also downsample identity
        if self.downsample:
            identity = self.max_pool(identity)
            
        # if expand channel, also pad zeros to identity
        if self.out_channels != self.in_channels:
            identity = torch.transpose(identity,-1,-2)
            ch1 = (self.out_channels-self.in_channels)//2
            ch2 = self.out_channels-self.in_channels-ch1
            identity = F.pad(identity, (ch1, ch2), "constant", 0)
            identity = torch.transpose(identity,-1,-2)
        
        # shortcut
        out = out + identity

        return out
    
class ResNet1D(nn.Module):
    """
    
    Input:
        X: (n_samples, n_channel, n_length)
        Y: (n_samples)
        
    Output:
        out: (n_samples)
        
    Parameters:
        in_channels: dim of input, the same as n_channel
        base_filters: number of filters in the first several conv layers; it doubles every increasefilter_gap blocks
        kernel_size: width of kernel
        stride: stride of kernel moving
        groups: set larger than 1 for a ResNeXt-style model
        n_block: number of blocks
        n_classes: number of classes
        
    """

    def __init__(self, in_channels, base_filters, kernel_size, stride, groups, n_block, n_classes, downsample_gap=2, increasefilter_gap=4, use_bn=True, use_do=True, verbose=False):
        super(ResNet1D, self).__init__()
        
        self.verbose = verbose
        self.n_block = n_block
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.use_bn = use_bn
        self.use_do = use_do

        self.downsample_gap = downsample_gap # 2 for base model
        self.increasefilter_gap = increasefilter_gap # 4 for base model

        # first block
        self.first_block_conv = MyConv1dPadSame(in_channels=in_channels, out_channels=base_filters, kernel_size=self.kernel_size, stride=1)
        self.first_block_bn = nn.BatchNorm1d(base_filters)
        self.first_block_relu = nn.ReLU()
        out_channels = base_filters
                
        # residual blocks
        self.basicblock_list = nn.ModuleList()
        for i_block in range(self.n_block):
            # is_first_block
            if i_block == 0:
                is_first_block = True
            else:
                is_first_block = False
            # downsample at every self.downsample_gap blocks
            if i_block % self.downsample_gap == 1:
                downsample = True
            else:
                downsample = False
            # in_channels and out_channels
            if is_first_block:
                in_channels = base_filters
                out_channels = in_channels
            else:
                # increase filters at every self.increasefilter_gap blocks
                in_channels = int(base_filters*2**((i_block-1)//self.increasefilter_gap))
                if (i_block % self.increasefilter_gap == 0) and (i_block != 0):
                    out_channels = in_channels * 2
                else:
                    out_channels = in_channels
            
            tmp_block = BasicBlock(
                in_channels=in_channels, 
                out_channels=out_channels, 
                kernel_size=self.kernel_size, 
                stride = self.stride, 
                groups = self.groups, 
                downsample=downsample, 
                use_bn = self.use_bn, 
                use_do = self.use_do, 
                is_first_block=is_first_block)
            self.basicblock_list.append(tmp_block)

        # final prediction
        self.final_bn = nn.BatchNorm1d(out_channels)
        self.final_relu = nn.ReLU(inplace=True)
        # self.do = nn.Dropout(p=0.5)
        self.dense = nn.Linear(out_channels, n_classes)
        # self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        
        out = x
        
        # first conv
        if self.verbose:
            print('input shape', out.shape)
        out = self.first_block_conv(out)
        if self.verbose:
            print('after first conv', out.shape)
        if self.use_bn:
            out = self.first_block_bn(out)
        out = self.first_block_relu(out)
        
        # residual blocks, every block has two conv
        for i_block in range(self.n_block):
            net = self.basicblock_list[i_block]
            if self.verbose:
                print('i_block: {0}, in_channels: {1}, out_channels: {2}, downsample: {3}'.format(i_block, net.in_channels, net.out_channels, net.downsample))
            out = net(out)
            if self.verbose:
                print(out.shape)

        # final prediction
        if self.use_bn:
            out = self.final_bn(out)
        out = self.final_relu(out)
        out = torch.mean(out, -1)
        if self.verbose:
            print('final pooling', out.shape)
        # out = self.do(out)
        out = self.dense(out)
        if self.verbose:
            print('dense', out.shape)
        # out = self.softmax(out)
        if self.verbose:
            print('softmax', out.shape)
        
        return out

here is the error traceback:

[W python_anomaly_mode.cpp:104] Warning: Error detected in CudnnBatchNormBackward. Traceback of forward call that caused the error:
  File "<string>", line 1, in <module>
  File "/usr/lib64/python3.6/multiprocessing/spawn.py", line 105, in spawn_main
    exitcode = _main(fd)
  File "/usr/lib64/python3.6/multiprocessing/spawn.py", line 118, in _main
    return self._bootstrap()
  File "/usr/lib64/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib64/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/alto/nima/textAnomaly/train_encoder_dd.py", line 168, in train
    h1 = h_net(x1_rep)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 619, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/alto/nima/textAnomaly/resent1D.py", line 275, in forward
    out = self.final_bn(out)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 136, in forward
    self.weight, self.bias, bn_training, exponential_average_factor, self.eps)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/functional.py", line 2058, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
 (function _print_stack)
(the same warning and tqdm progress-bar output are repeated by the other worker processes)
Traceback (most recent call last):
  File "train_encoder_dd.py", line 210, in <module>
    mp.spawn(train, nprocs=args.num_gpus, args=(args,))
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception: 

-- Process 2 terminated with the following error:
Traceback (most recent call last):
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/alto/nima/textAnomaly/train_encoder_dd.py", line 174, in train
    loss.backward()
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1024]] is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Hey @nima_rafiee, which PyTorch release are you using? Can you try v1.7+ if you are not using it? This bug is likely fixed by: [v1.7] Quick fix for view/inplace issue with DDP by albanD · Pull Request #46407 · pytorch/pytorch · GitHub

@mrshenli thanks for your reply. I’m using 1.7.1.

Hey @nima_rafiee, could you please share a self-contained repro, including how DDP was called?

cc @albanD have you seen similar errors recently? Since the same model passes in local training, it looks like the only difference is the scatter operator in DDP’s forward function. But I recall this was fixed by #46407 in v1.7?


@mrshenli here is the code



import wget
import os
os.environ["CUDA_VISIBLE_DEVICES"]= "1,2,3,4,5,6"
import argparse
from tqdm import tqdm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.distributed as dist
import torch.multiprocessing as mp
from apex import amp
from resent1D import ResNet1D
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
import numpy as np
vis = True
try:
   from torch.utils.tensorboard import SummaryWriter
except:
   vis = False 
   
wt = 0.001




def train(rank, args):
   print("Passed GPU:" ,rank)
   dist.init_process_group(backend='nccl',init_method='env://',world_size=args.num_gpus ,rank=rank)    
   torch.cuda.set_device(rank)
   exp_num = f'h_{args.h_dim}_bs{args.bs}_hlr{args.h_lr}' 
   if vis == True :
       args.writer = SummaryWriter(f'runs/{exp_num}')  

   h_net = ResNet1D( in_channels=768, 
       base_filters=128, # 64 for ResNet1D, 352 for ResNeXt1D
       kernel_size=16, 
       stride=2, 
       groups=32, 
       n_block=48, 
       n_classes=args.h_dim, 
       downsample_gap=6, 
       increasefilter_gap=12, 
       use_do=True)

   h_net_opt = optim.Adam(h_net.parameters(), lr=args.h_lr, weight_decay=args.weight_decay)    
   
   h_net.cuda(rank) 
   model_lst , h_net_opt = amp.initialize([h_net,], h_net_opt, opt_level="O1")
   
   h_net = model_lst[0]
   h_net = DistributedDataParallel(h_net, device_ids=[rank], find_unused_parameters=True)    
   

   ds = TextDataset(dataset_name='AG_NEWS', out_cls=[])  # TextDataset: custom dataset class, not included in this post
   sampler = DistributedSampler(ds, num_replicas=args.num_gpus)
   loader = DataLoader(ds, batch_size=args.bs, sampler=sampler, shuffle=False)    
   
   
   loader = tqdm(loader)
   for epoch in range(args.epoch):
       sampler.set_epoch(epoch)        
       for i , (x1 , x2 , x_rnd, label) in enumerate(loader):
           h_net_opt.zero_grad()
           # two forward passes through the DDP-wrapped model before backward()
           h1 = h_net(x1)
           h2 = h_net(x2)

           # loss_fn: the contrastive loss function, not included in this post
           loss = loss_fn(h1, h2, temprature=args.temprature)
           loss.backward()
           h_net_opt.step()
           loader.set_description(
               (
                   f' Epoch: {epoch + 1};  iter: {i} Loss: {loss.item()} '         
               )
           )

       if vis == True  :
           with torch.no_grad(): 
               args.writer.add_scalar("Loss", loss.item(), global_step=epoch, walltime=wt) 

       if rank == 0:
           #model_to_save = h_net.module if hasattr(h_net, 'module') else h_net  # Take care of distributed/parallel training
           torch.save(h_net.state_dict(), 'checkpoint/dd/h_net.checkpoint')
               
   args.writer.close()
   dist.destroy_process_group()
if __name__ == "__main__":
   
   parser = argparse.ArgumentParser()
   parser.add_argument('--bs', type=int, default=32)
   parser.add_argument('--epoch', type=int, default=50)    
   parser.add_argument('--h_lr', type=float, default=3e-4)   
   parser.add_argument('--h_dim', type=int, default=128)    
   parser.add_argument('--weight_decay', type=float, default=1e-6)    
   parser.add_argument('--temprature', type=float, default=0.5)        
      
   args = parser.parse_args()
   

   device_ids = list(range(torch.cuda.device_count()))
   gpus = len(device_ids)
   args.num_gpus = gpus 
       
   os.environ['MASTER_ADDR'] = '127.0.0.1'
   os.environ['MASTER_PORT'] = '8888'
   mp.spawn(train, nprocs=args.num_gpus, args=(args,))

@mrshenli I don’t think this is related to the view/inplace fix that is mentioned there.
There are no views involved in this example and the error is a “classic” in-place change of a Tensor that is needed for gradients. Since the failing Node is batchnorm, I would guess it is either the input to the Module or the weights of the affine transformation.

Is DDP modifying weights inplace by any chance? Or the optimizer is not used properly?
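
For context, here is a minimal sketch (my own example, unrelated to DDP or this model) of the “classic” form of this error, where a tensor that autograd saved for the backward pass is modified in place afterwards:

import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)   # x is saved by autograd to compute the gradient w.r.t. w
y = (w * x).sum()
x.add_(1.0)          # in-place change bumps x's version counter
y.backward()         # RuntimeError: ... has been modified by an inplace operation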

Thanks for your reply. To my understanding there is no typical problematic in-place operation in the network structure (it works in non-distributed mode), and I don’t know whether DDP introduces new in-place issues. For the optimizer, I’m using the standard pattern from the tutorials.

@mrshenli @albanD I found this: SyncBatchNorm — PyTorch 1.7.0 documentation. It seems normal BN cannot be used with DDP, but I don’t know how to use SyncBatchNorm in my code. Specifically, I cannot understand this line:

process_group = torch.distributed.new_group(process_ids)

I already initialised a process group using the:

dist.init_process_group(backend='nccl',init_method='env://',world_size=args.num_gpus ,rank=rank)    

Why should I make a new_group(), and how can I get process_ids?


Hi @nima_rafiee, have you solved your problem? I have just run into a similar issue in this topic: Using ‘DistributedDataParallel’ occurs a weird error - distributed - PyTorch Forums

@shaoming20798

Yes. My problem was that the regular batch norm was not working with DDP, so I replaced it with SyncBatchNorm.

Here is how I used the synced batch norm:

def train(rank, conf):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '8888'
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)  # world_size is defined elsewhere (e.g. the number of GPUs)
    torch.cuda.set_device(rank)   
    process_group = torch.distributed.new_group()
    model = Resnet18()
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group)   
    model.cuda(rank) 

I hope this solves your problem
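
For completeness, a rough sketch (assuming the rank variable from above) of how the converted model is then wrapped in DDP. If you don’t pass a process_group, convert_sync_batchnorm uses the default process group, so a separate new_group() is only needed if you want to sync the stats across a subset of processes:

from torch.nn.parallel import DistributedDataParallel

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)  # default process group
model.cuda(rank)
model = DistributedDataParallel(model, device_ids=[rank])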


@nima_rafiee I was trying to look into why the existing code didn’t work with BatchNorm. In the code you provided, there is the following import statement:

from resent1D import ResNet1D

Where can I find the resnet1D module?

here is the link:

@nima_rafiee Thanks for sharing the resnet1d dependencies, although I still can’t repro this locally since I’m now missing TextDataset:

NameError: name 'TextDataset' is not defined

Hi, thanks for posting your solution. It perfectly solves my problem.

But I’m still wondering why, in our situations, DDP doesn’t work with plain BN.

Both BN and SyncBN are supported by DDP and work well in my toy model:

class ToyModel(nn.Module):

    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm1d(10)
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.bn(self.relu(self.net1(x))))

Found the cause in my case: Inplace error if DistributedDataParallel module that contains a buffer is called twice · Issue #22095 · pytorch/pytorch · GitHub
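
For anyone hitting the same thing, here is a hedged sketch of the pattern described in that issue (not the exact repro). With the default broadcast_buffers=True, DDP syncs module buffers such as BatchNorm’s running statistics in place at the start of every forward call, so running the wrapped module twice before backward() can invalidate tensors saved by the first forward:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# assumes dist.init_process_group(...) has already been called and
# `rank` is this process's GPU index
model = nn.Sequential(nn.Conv1d(8, 8, kernel_size=3), nn.BatchNorm1d(8)).cuda(rank)
model = DDP(model, device_ids=[rank])       # broadcast_buffers=True by default

x1 = torch.randn(4, 8, 32, device=f"cuda:{rank}")
x2 = torch.randn(4, 8, 32, device=f"cuda:{rank}")

loss = model(x1).sum() + model(x2).sum()    # two forward passes before backward()
loss.backward()                             # can raise the inplace RuntimeError

# Common workarounds: pass broadcast_buffers=False to DDP, run a single forward
# over torch.cat([x1, x2]), or convert to SyncBatchNorm as in the fix above.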


Using SyncBatchNorm instead of BatchNorm worked perfectly in my case. Thank you for the advice.
