PyTorch Transformer decoder: "modified by an inplace operation" error (although I didn't use any in-place operations)

I am learning by designing a model that uses a Transformer encoder and decoder.

I train a classification model on the encoder output, and a generative model on the decoder output (the decoder takes the encoder output as its input). The model returns multiple outputs.

The following error occurred during training:

I traced the error using `torch.autograd.set_detect_anomaly(True)`.

I found posts about the same error on the PyTorch forum. However, those cases mostly involved in-place operations such as `+=` or `x[:, 0] = 0`, and the error went away once those were removed. But I didn't use any of these operations. I tried replacing unsqueeze() and squeeze() with view(), and also adding clone() to the tensor manipulations, but the error still hasn't been fixed. What is the problem?

Model code

from pytorch_pretrained_bert import BertTokenizer, BertForSequenceClassification, BertForQuestionAnswering
from tqdm import tqdm
import pandas as pd

# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

class SelfAttention(nn.Module):
    # Self-attention wrapper around nn.MultiheadAttention: forward(x) uses x as
    # query, key and value.
    def __init__(self, embedding_dim, num_heads):
        super(SelfAttention, self).__init__()

        # NOTE(review): this call is truncated in this copy (its remaining
        # arguments and closing parenthesis are missing), so the class does not
        # parse as shown. Encoder/Decoder elsewhere use batch_first=True, which
        # this call presumably also passed — TODO confirm.
        self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim, num_heads=num_heads,

    # Attend x to itself.
    # Returns the raw nn.MultiheadAttention result: a tuple
    # (attn_output, attn_weights) whose second element is None because
    # need_weights=False. Callers must take element [0] to get the tensor.
    def forward(self, x):
        query = x
        key = x
        value = x
        attn_output = self.multihead_attn(query, key, value, need_weights=False)

        # Despite its name, this is the (output, None) tuple, not a tensor.
        return attn_output

class Encoder(nn.Module):
    """Six-layer Transformer encoder with a mean-pooled sigmoid score head.

    Args:
        embedding_dim: model width (``d_model``) of every encoder layer;
            tensors are batch-first.

    ``forward(x)`` returns ``(out, cls_out)``: the full encoder output, and a
    score in (0, 1) obtained by averaging ``out`` over dim -2, projecting it
    to a single unit and applying a sigmoid.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        # self.pos_encoder = PositionalEncoding()
        # Template layer: nn.TransformerEncoder clones it num_layers times, so
        # this registered instance is not itself run in forward(). It is kept
        # because it is part of the module's state_dict keys.
        self.encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=8,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=self.encoder_layer,
            num_layers=6,
        )
        self.feedforward = nn.Linear(embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.encoder(x)
        # Pool over dim -2 (the sequence axis for batch-first input).
        pooled = hidden.mean(dim=-2)
        score = self.sigmoid(self.feedforward(pooled))
        return hidden, score

class Decoder(nn.Module):
    """Six-layer Transformer decoder (8 heads, batch-first tensors).

    Args:
        embedding_dim: model width (``d_model``) of every decoder layer.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        # Template layer: nn.TransformerDecoder clones it num_layers times, so
        # this registered instance is not itself run in forward(). It is kept
        # because it is part of the module's state_dict keys.
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=8,
            batch_first=True,
        )
        self.decoder = nn.TransformerDecoder(
            decoder_layer=self.decoder_layer,
            num_layers=6,
        )

    def forward(self, tgt, memory):
        """Decode ``tgt`` while cross-attending to ``memory``.

        No attention masks are passed. NOTE(review): without a causal
        ``tgt_mask``, every target position can attend to later positions.
        Confirm that this is intended for a generative decoder.
        """
        decoded = self.decoder(tgt, memory)
        return decoded

class AlzhBERT(nn.Module):
    # Model that combines self-attention over section.par (token- and
    # sentence-level) into a context tensor, concatenates it with the pooled
    # section.inv, and feeds the result to the Transformer encoder
    # (classification head). The decoder consumes section.next_uttr.
    # NOTE(review): this copy of the class is garbled (a missing `try:` and
    # truncated calls below), so it does not parse as shown.
    def __init__(self, embedding_dim):
        super(AlzhBERT, self).__init__()
        self.embedding_dim = embedding_dim
        # NOTE(review): not read anywhere in the visible forward().
        self.max_sent_length = 7

        # NOTE(review): these 10 attention modules are not used in the visible
        # forward() (the per-sentence loop is commented out), but their
        # parameters are still registered, saved in state_dict and passed to any
        # optimizer built from model.parameters().
        self.token_level_attn = nn.ModuleList([SelfAttention(self.embedding_dim, num_heads=8) for _ in range(10)])
        self.token_level_attn_single = SelfAttention(self.embedding_dim, num_heads=8)
        self.sentence_level_attn = SelfAttention(self.embedding_dim, num_heads=8)

        self.encoder = Encoder(embedding_dim=embedding_dim)
        self.decoder = Decoder(embedding_dim=embedding_dim)

    # Run every section of every item in X_batch through the attention stack,
    # the encoder and the decoder.
    # Returns (enc_outs, dec_outs): dicts keyed by each item's position in
    # X_batch.
    def forward(self, X_batch):
        i = 0

        enc_outs = {}
        dec_outs = {}
        for datastruct in X_batch:
            enc_outs[i] = []
            dec_outs[i] = []
            for section in datastruct.sections:
                # NOTE(review): `j` is never initialised in this copy. Because it
                # is assigned later in this function, this line would raise
                # UnboundLocalError; the original presumably set `j = 0` earlier.
                print(i, " + ", j)
                # NOTE(review): marking input data with requires_grad_(True) is not
                # needed to train the model's parameters.
                inv = section.inv.requires_grad_(True).to(device)
                y_dec = section.next_uttr.requires_grad_(True).to(device)
                par = section.par
                # print(par)
                # NOTE(review): a `try:` line is missing before the next line in
                # this copy (see the matching `except` below). The except branch
                # does not `continue`, so the rest of the body still runs.
                    tmp = par.dim()
                except AttributeError:
                    print("attr err")
                    j = j+1

                # par = par.permute(1,0,2)                # (seq_len, sent_len, embed) => self attention in a single pass
                # multiple self_attention modules
                # for p in par:
                # NOTE(review): the argument of this call is missing in this copy;
                # `[0]` presumably selects the attention output from the
                # (output, None) tuple that SelfAttention returns.
                result = self.token_level_attn_single([0]
                res = torch.mean(result, dim=-2).unsqueeze(0)

                # NOTE(review): same truncation as above.
                res_sent = self.sentence_level_attn([0]
                context = torch.mean(res_sent, dim=-3)

                inv_input = torch.mean(inv, dim=-2)
                # x_enc = torch.concat((inv_input, context))
                # x_enc = x_enc.view([1, -1, self.embedding_dim])
                # Encoder input: inv_input and context concatenated along dim 0,
                # plus a leading batch dimension of size 1.
                enc_out, cls_out = self.encoder(torch.concat([inv_input, context]).unsqueeze(0))
                # y_dec = torch.mean(y_dec, dim=-2).to(device)
                # enc_out = torch.mean(enc_out, dim=-2).unsqueeze(0).to(device)
                # NOTE(review): this call is truncated in this copy (the memory
                # argument is presumably enc_out). The appends of the results to
                # enc_outs[i] / dec_outs[i] are also not visible.
                dec_out = self.decoder(y_dec,

                j = j+1

            # NOTE(review): torch.tensor() copies the values into a new leaf
            # tensor, which cuts the autograd link to the encoder. A loss computed
            # from enc_outs[i] therefore cannot backpropagate into the model.
            # If the entries are equal-shape tensors, torch.stack(enc_outs[i])
            # would keep the graph.
            enc_outs[i] = torch.tensor(enc_outs[i], requires_grad=True)
            i = i + 1

        return enc_outs, dec_outs

Training code

# Select the first CUDA GPU if PyTorch can see one, else the CPU, and report it.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("device: ", device)

def train_loop(dataloader, model, loss_fn, optimizer, epochs):
    # Train `model` for `epochs` epochs, logging average losses to TensorBoard
    # and writing a checkpoint after each epoch.
    #   dataloader: iterable of (Xs, ys) batches that also has a `.dataset`.
    #   model: module whose forward(X) returns (enc_preds, dec_preds).
    #   loss_fn: loss callable applied to each prediction.
    #   optimizer: pair (enc_optimizer, dec_optimizer).
    #   epochs: number of passes over `dataloader`.
    # NOTE(review): this copy is garbled (the backward/step code is missing and
    # the checkpoint-saving statement is mangled), so it does not parse as shown.
    # dataloader = dataloader["train"]
    size = len(dataloader.dataset)
    writer = SummaryWriter()
    enc_optimizer = optimizer[0]
    # NOTE(review): dec_optimizer is not used anywhere in the visible code.
    dec_optimizer = optimizer[1]

    for epoch in range(epochs):
        enc_loss_hist = []
        dec_loss_hist = []
        accuracy = []

        print("======== epoch ", epoch, "==========\n")
        for i, (Xs, ys) in tqdm(enumerate(dataloader), desc="Train..."):
            # Split each batch with the project helper cross_validation(10, ...),
            # presumably into 10 train/valid folds.
            X_folds,  y_folds = cross_validation(10, Xs, ys)

            for X, y in zip(X_folds['train'], y_folds['train']):                    # Xf is a list of DataStruct objects
                # print("<Check Data>")
                # print("X 0: ", X[0])
                # print("label 0: ", y[0])

                # Prediction and Loss
                # X = batch_to_tensor(X)
                # X = torch.tensor(X).to(device)
                y = torch.tensor(y, dtype=torch.float32).to(device)

                enc_preds, dec_preds = model(X)

                for k in range(len(X)):
                    for t in range(len(enc_preds[k])):
                        # NOTE(review): PyTorch loss modules take (input, target);
                        # here the label is passed first. Confirm this is correct
                        # for the loss_fn in use.
                        # NOTE(review): requires_grad_(True) on a loss is a no-op if
                        # the loss is already in the graph. If the loss was cut off
                        # from the graph (e.g. by torch.tensor() in the model's
                        # forward), it only hides that problem: backward() will
                        # then not reach the model parameters.
                        enc_loss = loss_fn(y[k].to(device), enc_preds[k][t].to(device)).requires_grad_(True)
                        # NOTE(review): this passes a section object, whereas the
                        # model uses section.next_uttr as the decoder target.
                        # Presumably `.next_uttr` is intended here — confirm.
                        dec_loss = loss_fn(X[k].sections[t], dec_preds[k][t].to(device)).requires_grad_(True)

                        # Threshold the encoder score at 0.5. cls_loss counts
                        # correct predictions (a hit count, not a loss).
                        cls_out = torch.tensor(1 if enc_preds[k][t] >= 0.5 else 0)
                        cls_loss = torch.sum(cls_out == y[k])

                        # Backpropagation
                        # NOTE(review): the backward()/step() calls are missing from
                        # this copy. The reported error ("[torch.FloatTensor [768]]
                        # is at version 1; expected version 0", raised in
                        # NativeLayerNormBackward0) is what autograd reports when a
                        # graph is kept with retain_graph=True, optimizer.step()
                        # then updates the parameters (here likely a LayerNorm
                        # weight/bias) in place, and backward() runs again through
                        # the old graph. Likely fix: sum the losses for the batch,
                        # call backward() once without retain_graph, and only then
                        # call step() and zero_grad() on the optimizers.




                cross_validation_loop(X_folds["valid"], y_folds["valid"], model, loss_fn, epoch)

        # NOTE(review): nothing visible appends to enc_loss_hist, dec_loss_hist or
        # accuracy, so these means are taken over empty tensors and are NaN.
        enc_loss_save = torch.mean(torch.tensor(enc_loss_hist))
        dec_loss_save = torch.mean(torch.tensor(dec_loss_hist))
        accuracy_save = torch.mean(torch.tensor(accuracy, dtype=torch.float32))

        writer.add_scalar("Avg Enc Loss/train", enc_loss_save, epoch)
        writer.add_scalar("Avg Dec Loss/train", dec_loss_save, epoch)
        # NOTE(review): no global_step is passed here, unlike the two calls above.
        writer.add_scalar("Avg Accuracy/train", accuracy_save)

        # NOTE(review): `device` is torch.device('cuda:0'), so comparing it with
        # the string "cuda" likely never holds; `device.type == "cuda"` is
        # probably intended. Both assignments sit in the same branch, so the
        # second one overwrites the first; an `else:` seems to be missing in this
        # copy. If the branch is not taken, saved_model_dir is undefined below.
        if device == "cuda":
            saved_model_dir = "/home/juny/AlzheimerModel/checkpoint"
            saved_model_dir = "./saved_model"

        # NOTE(review): mangled in this copy. This presumably was torch.save(...)
        # of this dict to both paths, with `now` being a datetime (it is used via
        # now.strftime) rather than the tuple assigned here.
        now ={
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': enc_optimizer.state_dict(),
            'loss': [enc_loss_save, dec_loss_save],
        }, os.path.join('/home/juny/AlzheimerModel/checkpoint',
                        now.strftime("%Y-%m-%d-%H-%M") + "-e" + str(epoch) + ".pt")), os.path.join(saved_model_dir, "saved_model" + now.strftime("%Y-%m-%d-%H-%M") + ".pt"))
        # `i` and `X` hold the values from the last loop iteration.
        encloss, decloss, current = enc_loss_save, dec_loss_save.item(), i * len(X)
        # NOTE(review): the message is missing its closing "]".
        print(f"enc loss: {encloss:>7f} dec loss: {decloss:>7f} [{current:>5d}/{size:>5d}")


Error message

C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\autograd\ UserWarning: Error detected in NativeLayerNormBackward0. Traceback of forward call that caused the error:
  File "C:/Users/usr/PycharmProjects/project/", line 265, in <module>
    train_loop(dataloader=train_dataloader, model=model, loss_fn=loss_fn, optimizer=(enc_optimizer, dec_optimizer), epochs=epochs)
  File "C:/Users/usr/PycharmProjects/project/", line 47, in train_loop
    enc_preds, dec_preds = model(X)
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\usr\PycharmProjects\project\", line 139, in forward
    dec_out = self.decoder(y_dec,
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\usr\PycharmProjects\project\", line 84, in forward
    out = self.decoder(tgt, memory)
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\", line 291, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\", line 578, in forward
    x = self.norm3(x + self._ff_block(x))
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\modules\", line 189, in forward
    return F.layer_norm(
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\nn\", line 2503, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
 (Triggered internally at  C:\cb\pytorch_1000000000000\work\torch\csrc\autograd\python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Train...: 0it [00:05, ?it/s]
Traceback (most recent call last):
  File "C:/Users/usr/PycharmProjects/project/", line 265, in <module>
    train_loop(dataloader=train_dataloader, model=model, loss_fn=loss_fn, optimizer=(enc_optimizer, dec_optimizer), epochs=epochs)
  File "C:/Users/usr/PycharmProjects/project/", line 65, in train_loop
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\usr\anaconda3\envs\pytorch-python3.8\lib\site-packages\torch\autograd\", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [768]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Process finished with exit code 1

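These errors are often raised when retain_graph=True is used although it's not needed, sometimes added as a workaround for another error. A typical trigger is calling backward(retain_graph=True), then optimizer.step() (which updates the parameters in-place), and then calling backward() again through the retained graph. Could you explain why retain_graph=True is used in your code?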

I ran into the same error. I found that the main cause may be F.layer_norm, which seems to prevent using retain_graph=True.

My case is training an LLM with multiple losses. Since the losses are additive and may be independent, I decided to compute each loss and run backward() on each one separately, calling step() only after all losses are done. So I HAVE to keep the computational graph, to save GPU memory.

Is there any way to avoid this error with retain_graph=True?

Your use case of calling backward multiple times without calling optimizer.step() between these calls is valid and should not raise the error.
Could you post a minimal, executable code snippet that reproduces this error, please?