Weird inplace-operation-breaking-gradient when training with DDP

Hi, I’m facing a very weird error. When I pass the batched input through the network to compute the loss (dec_forward == "batch"), everything works fine. However, when I pass each individual sample through the network (dec_forward == "single"), I get the gradient-computation-has-been-modified-by-an-inplace-operation error.
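
Roughly, the two paths look like this (a sketch, not the actual script; dec_forward is my flag, and decoder / imgs_w stand in for the real model and inputs shown in the full snippet further below):

if dec_forward == "batch":
    # one forward over the whole batch: works fine
    outputs = decoder(imgs_w)
elif dec_forward == "single":
    # one forward per sample: triggers the in-place error under DDP
    outputs = torch.cat(
        [decoder(img_w.unsqueeze(0)) for img_w in imgs_w], dim=0
    )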

The traceback with torch.autograd.set_detect_anomaly(True) is shown below; it doesn’t really point to where the in-place operation takes place.

Traceback (most recent call last):
  File "pretrain_new.py", line 840, in <module>
    main(params)
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "pretrain_new.py", line 500, in main
    train_stats = train_one_epoch(
  File "pretrain_new.py", line 621, in train_one_epoch
    loss.backward()
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 3, 1, 1]] is at version 35; expected version 34 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Also, I’m currently using DistributedDataParallel. I tried dec_forward == "single" on a single GPU without DDP, and magically the training runs fine without throwing errors.

I’m totally lost here and would appreciate any help/suggestions. Thank you! @ptrblck @albanD

Could you post a minimal and executable code snippet reproducing the error, please?

import argparse
import datetime
import json
import math  # needed by l2_normalization below
import os
from pathlib import Path
import sys
import time
# import warnings


ROOT = "/".join(os.path.dirname(os.path.abspath(__file__)).split("/")[:-2])
sys.path.append(ROOT)
# isort: split

import numpy as np
import torch

torch.autograd.set_detect_anomaly(True)
import torch.backends.cudnn as cudnn
from torch.distributed.elastic.multiprocessing.errors import record
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import functional

# isort: split

# warnings.filterwarnings("ignore", category=DeprecationWarning)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --------------------- utils ---------------------
def to_dist(model, local_rank):
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank]
    )
    return model

def message_bce_loss(outputs, targets, margin=1.0):
    """
    Compute the message BCE loss.
    Args:
        margin: temperature of the sigmoid of the BCE loss
    """
    if torch.min(targets) == -1:
        targets = 0.5 * (targets + 1)

    return F.binary_cross_entropy_with_logits(
        outputs / margin, targets, reduction="mean"
    )

def generate_messages(n, k):
    """
    Generate random original messages.
    Args:
        n: Number of messages to generate
        k: length of the message
    Returns:
        msgs: boolean tensor of size nxk
    """
    return torch.rand((n, k)) > 0.5
# -------------------------------------------------

# --------------------- models ---------------------
class ConvBNRelu(nn.Module):
    """
    Building block used in HiDDeN network. Is a sequence of Convolution, Batch Normalization, and ReLU activation
    """

    def __init__(self, channels_in, channels_out, stride=1):
        super(ConvBNRelu, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, 3, stride, padding=1),
            nn.BatchNorm2d(channels_out),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)

def l2_normalization(x):
    l2_norm = torch.linalg.vector_norm(x, dim=(1, 2, 3)).view(-1, 1, 1, 1) # l2 norm of each image
    num_pixels = math.prod(x.size()[1:])
    return x / l2_norm * math.sqrt(num_pixels)


def delta_process(delta, last_process):
    if last_process == "tanh":
        return torch.tanh(delta)
    elif last_process == "l2":
        return l2_normalization(delta)
    elif last_process == "tanh-l2":
        return l2_normalization(torch.tanh(delta))
    elif last_process == "linf":
        # max |delta| of each image; torch.amax returns a tensor directly,
        # so no [0] indexing (that would keep only the first image's max)
        denom = torch.amax(torch.abs(delta), dim=(1, 2, 3), keepdim=True)
        return delta / denom
    elif last_process == "none":
        return delta
    else:
        raise NotImplementedError


class VanillaEncoder(nn.Module):
    """
    Inserts a watermark into an image.
    """
    def __init__(self, num_blocks, num_bits, channels, normalizer=None, last_process="tanh", output="w"):
        super(VanillaEncoder, self).__init__()
        layers = [ConvBNRelu(3, channels)]

        for _ in range(num_blocks - 1):
            layer = ConvBNRelu(channels, channels)
            layers.append(layer)

        self.conv_bns = nn.Sequential(*layers)
        self.after_concat_layer = ConvBNRelu(channels + 3 + num_bits, channels)
        self.final_layer = nn.Conv2d(channels, 3, kernel_size=1)

        assert last_process in ["tanh", "none", "l2", "linf", "tanh-l2"]
        self.last_process = last_process
        self.normalizer = normalizer
        assert output in ["w", "imgw"]
        self.output = output

    def forward(self, imgs, msgs):
        # take normalizer as part of the model
        # so that we can easily handle image pixel
        # value range in the training/testing script
        if self.normalizer is not None:
            imgs = self.normalizer(imgs)

        msgs = msgs.unsqueeze(-1).unsqueeze(-1)  # b l 1 1
        msgs = msgs.expand(-1, -1, imgs.size(-2), imgs.size(-1))  # b l h w

        encoded_image = self.conv_bns(imgs)

        concat = torch.cat([msgs, encoded_image, imgs], dim=1)
        residual = self.after_concat_layer(concat)
        residual = self.final_layer(residual) # b 3 h w

        if self.output == "w":
            return delta_process(residual, self.last_process)
        elif self.output == "imgw":
            return 0.5 * (delta_process(residual, self.last_process) + 1)  # [-1, 1] -> [0, 1]


class VanillaDecoder(nn.Module):
    """
    Decoder module. Receives a watermarked image and extracts the watermark.
    """
    def __init__(self, num_blocks, num_bits, channels, normalizer=None, redundancy=1):
        super(VanillaDecoder, self).__init__()

        layers = [ConvBNRelu(3, channels)]
        for _ in range(num_blocks - 1):
            layers.append(ConvBNRelu(channels, channels))

        layers.append(ConvBNRelu(channels, num_bits * redundancy))
        layers.append(nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        self.layers = nn.Sequential(*layers)

        self.linear = nn.Linear(num_bits * redundancy, num_bits * redundancy)

        self.num_bits = num_bits
        self.redundancy = redundancy
        self.normalizer = normalizer

    def forward(self, img_w):
        # take normalizer as part of the model
        # so that we can easily handle image pixel
        # value range in the training/testing script
        if self.normalizer is not None:
            img_w = self.normalizer(img_w)

        x = self.layers(img_w)  # b d 1 1
        x = x.squeeze(-1).squeeze(-1)  # b d
        x = self.linear(x)

        x = x.view(-1, self.num_bits, self.redundancy)  # b k*r -> b k r
        x = torch.sum(x, dim=-1)  # b k r -> b k

        return x

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# https://github.com/BorealisAI/advertorch/blob/master/advertorch/utils.py
class Normalizer(nn.Module):
    def __init__(self, mean=IMAGENET_MEAN, std=IMAGENET_STD):
        super(Normalizer, self).__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean).view(1, 3, 1, 1)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std).view(1, 3, 1, 1)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        return tensor.sub(self.mean).div(self.std)

    def extra_repr(self):
        return "mean={}, std={}".format(self.mean, self.std)
# --------------------------------------------------



def main(dist=True):
    # Distributed mode
    if dist:
        local_rank = int(os.environ["LOCAL_RANK"])

        torch.distributed.init_process_group(
            init_method="env://",
            backend="nccl",
        )
        torch.cuda.set_device(local_rank)
        
        cudnn.benchmark = False
        cudnn.deterministic = True
    else:
        local_rank = 0

    # Set seeds for reproductibility
    seed = local_rank
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    # Build encoder
    normalizer = Normalizer()
    encoder = VanillaEncoder(
        num_blocks=4,
        num_bits=32,
        channels=64,
        last_process="none",
        normalizer=normalizer,
        output="imgw",
    )

    # Build decoder
    print(">>> Building decoder...")
    decoder = VanillaDecoder(
        num_blocks=8,
        num_bits=32,
        channels=64,
        normalizer=normalizer,
    )

    # Adapt bn momentum
    all_modules = [*decoder.modules(), *encoder.modules()]
    for module in all_modules:
        if isinstance(module, torch.nn.BatchNorm2d):
            module.momentum = 0.01
    
    # Distributed training
    encoder.to(device)
    decoder.to(device)
    if dist:
        encoder = to_dist(encoder, local_rank)
        decoder = to_dist(decoder, local_rank)
            
    # Build optimizer and scheduler
    to_optim = [*encoder.parameters(), *decoder.parameters()]
    optimizer = torch.optim.AdamW(to_optim, lr=0.001)

    # just a forward pass
    encoder.train()
    decoder.train()
    
    imgs = 0.5 * torch.ones(32, 3, 128, 128).to(device, non_blocking=True)  # b c h w

    msgs_ori = generate_messages(imgs.shape[0], 32).to(
        device, non_blocking=True
    )  # n k
    msgs = 2 * msgs_ori.type(torch.float) - 1  # b k

    imgs_w = encoder(imgs, msgs).clamp(0, 1)  # b c h w

    outputs = []
    for ii, img_w in enumerate(imgs_w):
        outputs.append(decoder(img_w.unsqueeze(0)))
    outputs = torch.cat(outputs, dim=0)

    loss = message_bce_loss(outputs, msgs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dist", action="store_true")
    args = parser.parse_args()
    
    main(args.dist)

Hi @ptrblck! Please see the code snippet above; it’s still a bit long because of the model architecture definitions.

The command to reproduce the error is:

CUDA_VISIBLE_DEVICES="2,3" \
torchrun --nproc_per_node=2 --nnodes 1 --master_port 16000 pretrain_minimal.py --dist

which uses DDP and produces the following error and traceback:

/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in DivBackward0. Traceback of forward call that caused the error:
  File "pretrain_minimal.py", line 295, in <module>
    main(args.dist)
  File "pretrain_minimal.py", line 281, in main
    outputs.append(decoder(img_w.unsqueeze(0)))
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "pretrain_minimal.py", line 173, in forward
    img_w = self.normalizer(img_w)
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "pretrain_minimal.py", line 199, in forward
    return tensor.sub(self.mean).div(self.std)
(Triggered internally at …/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "pretrain_minimal.py", line 295, in <module>
    main(args.dist)
  File "pretrain_minimal.py", line 286, in main
    loss.backward()
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/jz288/anaconda3/envs/sd2wm/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 3, 1, 1]] is at version 35; expected version 34 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

This time, fortunately, the anomaly trace showed that the error was in the input normalization, tensor.sub(self.mean).div(self.std). However, I’m still confused, since to my understanding this is not an in-place operation.
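
For my own sanity, I checked what autograd actually considers "in-place": each tensor carries a version counter that in-place ops bump, and backward raises exactly this error when a saved tensor's version no longer matches the one recorded at save time. A small experiment (_version is an internal attribute, so this is only for poking around):

import torch

t = torch.ones(1, 3, 1, 1)        # same shape as the tensor in the error
print(t._version)                  # 0

out = t.sub(2.0).div(3.0)          # out-of-place: returns new tensors
print(t._version)                  # still 0, t itself is untouched

t.copy_(torch.zeros_like(t))       # in-place: bumps the version counter
print(t._version)                  # 1

So sub/div themselves don’t mutate anything. The [1, 3, 1, 1] shape in the error message matches the mean/std buffers, which suggests something else is writing into them between the forward and the backward.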

Also, the following command without using DDP runs fine (no error):

CUDA_VISIBLE_DEVICES="2" python pretrain_minimal.py

And lastly, making the following change at the end of main() (passing the whole batch through the decoder rather than each sample individually) also avoids the error:

#outputs = []
#for ii, img_w in enumerate(imgs_w):
#    outputs.append(decoder(img_w.unsqueeze(0)))
#outputs = torch.cat(outputs, dim=0)
outputs = decoder(imgs_w)

To summarize, the inplace-operation-breaking-gradient error only occurs when 1) DDP is used and 2) each input sample is passed through the model individually. Thank you for taking the time to look into this! @ptrblck

OK, after much trial and error, it seems the issue is with registering the mean and standard deviation as buffers. If I don’t register them as buffers but simply set self.mean = mean and self.std = std, the issue is gone.
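
If I understand DDP correctly, that would explain it: with broadcast_buffers=True (the default), DDP copies rank 0’s buffers into every replica’s buffers in place at the start of each forward call, so running one forward per sample bumps the buffers’ version counters between the forward that saved them and the backward. It would also fit the other observations: without DDP nothing rewrites the buffers, and with one batched forward per backward the buffers are only broadcast once. A sketch of the two workarounds I see; the attribute-based Normalizer is roughly what I ended up with, and the broadcast_buffers knob is untested in this exact script:

# Workaround 1: plain attributes instead of registered buffers.
# Plain attributes don't follow .to(device), so move them in forward.
class Normalizer(nn.Module):
    def __init__(self, mean=IMAGENET_MEAN, std=IMAGENET_STD):
        super().__init__()
        self.mean = torch.tensor(mean).view(1, 3, 1, 1)
        self.std = torch.tensor(std).view(1, 3, 1, 1)

    def forward(self, tensor):
        mean = self.mean.to(tensor.device)
        std = self.std.to(tensor.device)
        return tensor.sub(mean).div(std)

# Workaround 2 (untested here): keep the buffers but tell DDP not to
# rewrite them on every forward.
model = nn.parallel.DistributedDataParallel(
    model, device_ids=[local_rank], broadcast_buffers=False
)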

That’s indeed weird and I cannot reproduce the issue. I’ve added a print("done") statement after the main(args.dist) call and get:

torchrun --nproc_per_node=2 tmp.py 
[2023-08-21 00:59:52,421] torch.distributed.run: [WARNING] 
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
>>> Building decoder...
>>> Building decoder...
done
done

without seeing any errors.