LLaMA 2 training: in-place operation error

I tried to fine-tune my LLaMA 2 model, but it raises an error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [32, 90, 128]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead.
The code is from the llama GitHub repo, but I removed all @torch.inference_mode decorators so that the outputs get a grad_fn. This is the forward function the problem comes from:

def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.

        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

And this is the full traceback (with torch.autograd.set_detect_anomaly(True) enabled):

/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in BmmBackward0. Traceback of forward call that caused the error:
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 1053, in launch_instance
    app.start()
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 737, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 524, in dispatch_queue
    await self.process_one()
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 513, in process_one
    await dispatch(*args)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 418, in dispatch_shell
    await result
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py", line 758, in execute_request
    reply_content = await reply_content
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 426, in do_execute
    res = shell.run_cell(
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py", line 549, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3046, in run_cell
    result = self._run_cell(
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3101, in _run_cell
    result = runner(coro)
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3306, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3488, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py", line 3548, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_1760/2845453657.py", line 2, in <module>
    train_loop(
  File "/tmp/ipykernel_1760/31845858.py", line 29, in train_loop
    gen_train, _, prob_ratio = gen_model(emb_in[:, :i])  # gen_train has the shape of (batch_size,) and is int
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_1760/1024606268.py", line 14, in forward
    gen_train, probs_train = self.gen_llm(transed, 3000, temperature=0)  # generated next token, for training
  File "/workspace/Music2Story/epic_TEXTgen/llama/generation.py", line 281, in generate_next_token
    logits = self.model.forward(embeds[:, prev_pos:cur_pos], prev_pos)
  File "/workspace/Music2Story/epic_TEXTgen/llama/model.py", line 490, in forward
    h = layer(h, start_pos, freqs_cis, mask)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workspace/Music2Story/epic_TEXTgen/llama/model.py", line 406, in forward
    h = x + self.attention.forward(
  File "/workspace/Music2Story/epic_TEXTgen/llama/model.py", line 302, in forward
    output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[35], line 2
      1 print(
----> 2     train_loop(
      3         train_loader=train_dataloader,
      4         gen_model=fg,
      5         audio_to_emb_model=mm,
      6         optimizer=adam,
      7         loss_fn=lgppo,
      8         epoch=5,
      9         device=device,
     10         batch_size=3,
     11     )
     12 )

Cell In[34], line 39, in train_loop(train_loader, gen_model, audio_to_emb_model, optimizer, loss_fn, epoch, device, batch_size)
     36     pass
     38 loss, rewards = loss_fn([[j for j in i if j != -1] for i in generated_tokens], audios, prob_ratio)
---> 39 loss.backward()
     40 losses.append(loss.item())
     41 optimizer.step()

File /usr/local/lib/python3.10/dist-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File /usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [32, 90, 128]], which is output 0 of AsStridedBackward0, is at version 4; expected version 3 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I don't understand how matmul could cause an in-place operation, and I'm very confused. Thanks to anyone who can help!

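It looks like the problem is probably the KV cache, specifically the in-place writes self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk and self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv.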
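To see how a matmul ends up in this error, it helps to know that autograd tracks in-place changes with a per-tensor version counter, and that a view shares the counter of the tensor it was sliced from. A minimal sketch, assuming a small zero tensor as a stand-in for self.cache_v (the shapes are arbitrary) and reading the internal _version attribute purely for inspection:

```python
import torch

# Stand-in for self.cache_v with shape (max_batch, max_seq_len, n_heads, head_dim).
cache_v = torch.zeros(1, 8, 2, 4)
# A view of the filled part, like self.cache_v[:bsz, : start_pos + seqlen].
values = cache_v[:1, :2]

# `_version` is an internal attribute; it is read here only to illustrate.
v0 = cache_v._version
# Each decoding step writes into the same buffer in place, like
# self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
cache_v[:1, 0:1] = torch.randn(1, 1, 2, 4)
v1 = cache_v._version
cache_v[:1, 1:2] = torch.randn(1, 1, 2, 4)
v2 = cache_v._version

print(v0 < v1 < v2)                          # True: every write bumps the counter
print(values._version == cache_v._version)   # True: the view shares the base's counter
```

Writing to a different slot of the buffer still bumps the counter, because it belongs to the whole buffer, not to the slice that was written.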

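Since the cache tensor is modified in place on every forward call, back-propagating through it fails. The values tensor passed to torch.matmul(scores, values) is a view of self.cache_v, and autograd saves it for the backward pass. When a later forward call (for example, the next decoding step) writes the next token into the same cache, the cache's version counter increases, so the saved tensor no longer matches the version autograd recorded. That is why the anomaly trace points at matmul even though matmul itself is not in-place. You wouldn't want to back-propagate through the cache anyway, since a KV cache is only used for inference. This repo doesn't look like it's designed to run without KV caching; it looks like an inference-only implementation.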
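The failure can be reproduced in a few lines, without the model. This is a minimal sketch under assumed toy shapes: cache_v stands in for self.cache_v, w_v for the value projection self.wv, and scores for the attention probabilities. A view of the cache is saved by matmul, a second write to the cache happens before backward, and backward then raises the same kind of RuntimeError.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a preallocated cache buffer (like self.cache_v)
# and a trainable projection weight (like self.wv).
cache_v = torch.zeros(4, 3)
w_v = torch.randn(3, 3, requires_grad=True)

# Decoding step 1: write the new token's values into the cache in place,
# then read a view of the filled part back (like `values` in forward()).
cache_v[0:1] = torch.randn(1, 3) @ w_v
values = cache_v[:1]                          # view sharing cache_v's version counter
scores = torch.randn(2, 1, requires_grad=True)
out = scores @ values                         # matmul saves `values` for its backward

# Decoding step 2: the next token goes into a different slot of the same
# buffer, which still bumps the version counter that `values` shares.
cache_v[1:2] = torch.randn(1, 3) @ w_v

error_message = ""
try:
    out.sum().backward()
except RuntimeError as err:
    error_message = str(err)

print(error_message.splitlines()[0] if error_message else "no error")
```

Removing the second write (or calling backward before it) makes backward succeed, which matches the explanation above.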

There are many implementations of LLaMA. If your goal is fine-tuning, you might want to try a different one.
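If you really do need gradients through token-by-token generation, one possible workaround (a swapped-in technique, not what the llama repo does) is a functional KV cache: instead of writing into a preallocated buffer, each step returns new key/value tensors built with torch.cat, so nothing saved for backward is ever modified. This sketch is single-head and omits rotary embeddings, repeat_kv and the model's real projections; attention_step and the Linear layers are hypothetical stand-ins. Every step's keys and values stay alive in the graph until backward, so memory grows with the generated length.

```python
import math

import torch
import torch.nn.functional as F


def attention_step(xq, xk, xv, cache=None):
    """Hypothetical single-head causal attention step with an out-of-place KV cache.

    xq, xk, xv: (bsz, seqlen, head_dim) projections for the new tokens.
    cache: (keys, values) returned by the previous step, or None.
    Returns (output, new_cache).
    """
    if cache is None:
        keys, values = xk, xv
    else:
        # torch.cat allocates new tensors instead of writing into a buffer,
        # so tensors saved by earlier steps are never modified in place.
        keys = torch.cat([cache[0], xk], dim=1)
        values = torch.cat([cache[1], xv], dim=1)

    seqlen, total = xq.shape[1], keys.shape[1]
    start = total - seqlen
    # Query i sits at absolute position start + i and may see keys 0..start + i.
    allowed = torch.ones(seqlen, total, dtype=torch.bool).tril(diagonal=start)

    scores = xq @ keys.transpose(1, 2) / math.sqrt(xq.shape[-1])
    scores = scores.masked_fill(~allowed, float("-inf"))
    probs = F.softmax(scores, dim=-1)
    return probs @ values, (keys, values)


# Hypothetical usage: a prompt step followed by one decoding step.
torch.manual_seed(0)
head_dim = 8
wq, wk, wv = (torch.nn.Linear(head_dim, head_dim) for _ in range(3))

prompt = torch.randn(2, 5, head_dim)                  # (bsz, seqlen, dim)
out1, cache = attention_step(wq(prompt), wk(prompt), wv(prompt))
new_tok = torch.randn(2, 1, head_dim)
out2, cache = attention_step(wq(new_tok), wk(new_tok), wv(new_tok), cache)

loss = out1.sum() + out2.sum()
loss.backward()                                       # no in-place error here
print(wv.weight.grad is not None)                     # True
```

For plain fine-tuning on complete sequences, the cache is unnecessary: a single call with cache=None already applies the causal mask over the whole sequence.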

You might also want to look at llama-recipes, which includes a fine-tuning script: https://github.com/facebookresearch/llama-recipes/blob/main/src/llama_recipes/finetuning.py