Llama-2 CUDA OOM during inference but not training

It looks like I’m leaking CUDA memory during inference but not during training. It might have to do with the custom model that I’ve implemented, but I don’t know how to debug it.

My custom model is a SiameseLlama, meaning it’s a Llama-2 model with two heads: one language modeling head and one classification head. Both heads share the same Llama-2 backbone. In practice, this is implemented by initializing a LlamaForSequenceClassification and a LlamaForCausalLM as two separate models, deleting the classifier’s backbone, and pointing the classifier at the language model’s backbone:

del classifier.model
gc.collect()
classifier.model = lm.model
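
In full, the construction looks roughly like this (a simplified sketch: the checkpoint name, 4-bit settings, and num_labels are illustrative, not my exact setup):

import gc
import torch
from transformers import (
    BitsAndBytesConfig,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
)

# Illustrative checkpoint and quantization config; my real setup differs in the details.
checkpoint = "meta-llama/Llama-2-7b-hf"
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

lm = LlamaForCausalLM.from_pretrained(
    checkpoint, quantization_config=bnb_config, device_map="auto"
)
classifier = LlamaForSequenceClassification.from_pretrained(
    checkpoint, num_labels=5, quantization_config=bnb_config, device_map="auto"
)

# Drop the classifier's own backbone and point it at the LM's backbone
# so that both heads share the same Llama-2 weights.
del classifier.model
gc.collect()
classifier.model = lm.model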

This is my SiameseLlama model file:

from utils.metrics import WeightedKappa
from configs.training import TrainConfig

from dataclasses import dataclass, field
import os
import torch
from transformers import (
    LlamaForSequenceClassification,
    LlamaForCausalLM,
    PreTrainedModel,
    BitsAndBytesConfig,
)
from transformers.modeling_outputs import (
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from typing import Callable


type DataDict = dict[str, torch.Tensor]


@dataclass
class SiameseOutput:
    classifier_output: SequenceClassifierOutputWithPast | None = None
    lm_output: CausalLMOutputWithPast | None = None
    loss: torch.Tensor | None = None  # scalar loss returned by SiameseLoss.forward()
    preds: DataDict = field(default_factory=dict)


class SiameseLoss(torch.nn.Module):
    def __init__(
        self,
        num_classes: int = 5,
        score_feedback_ratio: tuple[float, float] = (0.5, 0.5),
        scoring_loss: str = "cross_entropy",
    ) -> None:
        """
        `scoring_loss`: {"cross_entropy", "weighted_kappa", "quadratic_weighted_kappa"}
        """
        super().__init__()
        if scoring_loss not in (
            valid_scoring_losses := {
                "cross_entropy",
                "linear_weighted_kappa",
                "quadratic_weighted_kappa",
            }
        ):
            raise AssertionError(f"`scoring_loss` must be in {str(valid_scoring_losses)}")
        self.scoring_loss = scoring_loss
        if self.scoring_loss != "cross_entropy":
            self.kappa_weighting = self.scoring_loss.split("_")[0]
        self.num_classes = num_classes
        # normalize the score_feedback_ratio weights so that they sum to 1
        self.score_feedback_ratio: tuple = tuple(
            w / sum(score_feedback_ratio) for w in score_feedback_ratio
        )

    def forward(
        self,
        classifier_output: SequenceClassifierOutputWithPast | None = None,
        lm_output: CausalLMOutputWithPast | None = None,
        classifier_labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        `classifier_labels` can be left as None if `scoring_loss == "cross_entropy"`.
        Returns a scalar torch.Tensor that can be backpropagated with `.backward()`.
        """
        if not (classifier_output or lm_output):
            raise AssertionError("Missing `classifier_output` and/or `lm_output`.")

        if classifier_output is not None and classifier_labels is None:
            raise AssertionError(
                "When passing the classifier output, also have to pass the true labels."
            )

        if not classifier_output:
            classifier_loss = 0
        elif self.scoring_loss == "cross_entropy":
            if classifier_output.loss is None:
                raise ValueError(
                    "`classifier_output` does not have a `.loss` attribute. This probably means that the classification head did not receive the labels during its forward pass."
                )
            classifier_loss = classifier_output.loss
        else:  # weighted kappa loss
            weighted_kappa = WeightedKappa(
                weighting=self.kappa_weighting, num_classes=self.num_classes, as_loss=True
            )
            classifier_loss = weighted_kappa(output=classifier_output, labels=classifier_labels)

        if lm_output:
            if lm_output.loss is None:
                raise ValueError(
                    "`lm_output.loss` is None. This probably means that the LM head did not receive the labels during its forward pass."
                )
            lm_loss = lm_output.loss
        else:
            lm_loss = 0

        siamese_loss = (
            classifier_loss * self.score_feedback_ratio[0] + lm_loss * self.score_feedback_ratio[1]
        )
        assert siamese_loss != 0

        return siamese_loss


class SiameseLlama:
    def __init__(
        self,
        models: dict[str, PreTrainedModel | None],
        prompting_strategy: str = "vanilla",
        score_feedback_ratio: tuple[float, float] = (0.5, 0.5),
        scoring_loss: str = "cross_entropy",
    ):
        if not models["causal_lm"] and not models["classifier"]:
            raise ValueError(
                "Must pass at least a `causal_lm` or `classifier` in the `models` dict."
            )

        self.score_feedback_ratio: tuple[float, float] = score_feedback_ratio
        self.scoring_loss: str = scoring_loss
        self.prompting_strategy: str = prompting_strategy
        self.tasks: list[str] = self.prompting_strategy.split("_then_")

        self.models = models
        for model in self.models.values():
            if model:
                self.parameters: Callable = model.parameters
                self.resize_token_embeddings: Callable = model.resize_token_embeddings
                # a bool that's True if model is in training mode, False if in eval mode
                self.training: bool = model.training
                break

    def train(self) -> None:
        self.training = True
        for model in self.models.values():
            if model:
                model.train()

    def eval(self) -> None:
        self.training = False
        for model in self.models.values():
            if model:
                model.eval()

    def to(self, device, *args, **kwargs) -> None:
        for model in self.models.values():
            if model:
                model.to(device)

    def multitask_forward(
        self,
        datapoint: dict[str, DataDict],
        **kwargs,
    ) -> SiameseOutput:
        """
        Runs each task's sub-model on its slice of the batch and returns a SiameseOutput holding the per-task outputs, predictions, and the combined loss.
        """

        siamese_output = SiameseOutput()

        for task in datapoint.keys():
            task_data: DataDict = datapoint[task]
            for key in task_data:
                task_data[key] = task_data[key].to("cuda")

            if task == "score":
                task_output = self.models["classifier"](**task_data, **kwargs)
                siamese_output.classifier_output = task_output

            else:  # task == "feedback"
                task_output = self.models["causal_lm"](**task_data, **kwargs)
                siamese_output.lm_output = task_output

            siamese_output.preds[task] = torch.argmax(task_output.logits, dim=-1)

        siamese_loss = SiameseLoss(
            score_feedback_ratio=self.score_feedback_ratio, scoring_loss=self.scoring_loss
        )
        siamese_output.loss = siamese_loss(
            classifier_output=siamese_output.classifier_output,
            lm_output=siamese_output.lm_output,
            classifier_labels=datapoint.get("score", {}).get("labels", None),
        )

        return siamese_output
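
With the two linked models from the sketch above, the wrapper is then constructed roughly like this (the prompting_strategy value here is illustrative; the 4-bit models are already placed on the GPU when loaded, so I don’t call .to()):

siamese = SiameseLlama(
    models={"causal_lm": lm, "classifier": classifier},
    prompting_strategy="score_then_feedback",  # illustrative; split on "_then_" into tasks
    score_feedback_ratio=(0.5, 0.5),
    scoring_loss="cross_entropy",
)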

I run my training and eval script with this model while tracking CUDA stats. Below is the output from the very first evaluation (before any training). Note how the memory stats increase from eval step 0 to eval step 1: step 0 completes successfully, but memory runs out during step 1.

STEP 0
CUDA memory at the start: 4 GB
Max CUDA memory allocated was 8 GB
Max CUDA memory reserved was 9 GB
Peak active CUDA memory was 8 GB
Cuda Malloc retries : 0
CPU Total Peak Memory consumed (max): 1 GB

evaluating Epoch:  50%|█████     | 1/2 [00:11<00:11, 11.92s/it]
STEP 1
CUDA memory at the start: 8 GB
Max CUDA memory allocated was 9 GB
Max CUDA memory reserved was 10 GB
Peak active CUDA memory was 9 GB
Cuda Malloc retries : 1
CPU Total Peak Memory consumed (max): 2 GB

evaluating Epoch:  50%|█████     | 1/2 [00:15<00:15, 15.32s/it]
╭───────────────────── Traceback (most recent call last) ──────────────────────╮
...
│ /workspace/students/mai/master/llm-feedbacks/train/utils/train_utils.py:342  │
│ in evaluation                                                                │
│                                                                              │
│   339 │   │   │   # with torch.no_grad():                                    │
│   340 │   │   │   │   # Forward pass and compute loss                        │
│   341 │   │   │   │   # outputs = model(**batch)                             │
│ ❱ 342 │   │   │   │   outputs = model.multitask_forward(batch)               │
│   343 │   │   │   │   loss = outputs.loss                                    │
│   344 │   │   │   │   eval_loss += loss.detach().float()                     │
│   345 │   │   │   # Decode predictions and add to evaluation predictions lis │
│                                                                              │
│ /workspace/students/mai/master/llm-feedbacks/train/models/siamese_llama.py:1 │
│ 66 in multitask_forward                                                      │
│                                                                              │
│   163 │   │   │   │   task_data[key] = task_data[key].to("cuda")             │
│   164 │   │   │                                                              │
│   165 │   │   │   if task == "score":                                        │
│ ❱ 166 │   │   │   │   task_output = self.models["classifier"](**task_data, * │
│   167 │   │   │   │   siamese_output.classifier_output = task_output         │
│   168 │   │   │                                                              │
│   169 │   │   │   else:  # task == "feedback"                                │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:1352 in forward                   │
│                                                                              │
│   1349 │   │   """                                                           │
│   1350 │   │   return_dict = return_dict if return_dict is not None else sel │
│   1351 │   │                                                                 │
│ ❱ 1352 │   │   transformer_outputs = self.model(                             │
│   1353 │   │   │   input_ids,                                                │
│   1354 │   │   │   attention_mask=attention_mask,                            │
│   1355 │   │   │   position_ids=position_ids,                                │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:968 in forward                    │
│                                                                              │
│    965 │   │   │   │   │   cache_position,                                   │
│    966 │   │   │   │   )                                                     │
│    967 │   │   │   else:                                                     │
│ ❱  968 │   │   │   │   layer_outputs = decoder_layer(                        │
│    969 │   │   │   │   │   hidden_states,                                    │
│    970 │   │   │   │   │   attention_mask=causal_mask,                       │
│    971 │   │   │   │   │   position_ids=position_ids,                        │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:727 in forward                    │
│                                                                              │
│    724 │   │   # Fully Connected                                             │
│    725 │   │   residual = hidden_states                                      │
│    726 │   │   hidden_states = self.post_attention_layernorm(hidden_states)  │
│ ❱  727 │   │   hidden_states = self.mlp(hidden_states)                       │
│    728 │   │   hidden_states = residual + hidden_states                      │
│    729 │   │                                                                 │
│    730 │   │   outputs = (hidden_states,)                                    │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ ransformers/models/llama/modeling_llama.py:216 in forward                    │
│                                                                              │
│    213 │   │   │   ]                                                         │
│    214 │   │   │   down_proj = sum(down_proj)                                │
│    215 │   │   else:                                                         │
│ ❱  216 │   │   │   down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) │
│    217 │   │                                                                 │
│    218 │   │   return down_proj                                              │
│    219                                                                       │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1532 in _wrapped_call_impl                         │
│                                                                              │
│   1529 │   │   if self._compiled_call_impl is not None:                      │
│   1530 │   │   │   return self._compiled_call_impl(*args, **kwargs)  # type: │
│   1531 │   │   else:                                                         │
│ ❱ 1532 │   │   │   return self._call_impl(*args, **kwargs)                   │
│   1533 │                                                                     │
│   1534 │   def _call_impl(self, *args, **kwargs):                            │
│   1535 │   │   forward_call = (self._slow_forward if torch._C._get_tracing_s │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/nn/modules/module.py:1541 in _call_impl                                 │
│                                                                              │
│   1538 │   │   if not (self._backward_hooks or self._backward_pre_hooks or s │
│   1539 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hoo │
│   1540 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks │
│ ❱ 1541 │   │   │   return forward_call(*args, **kwargs)                      │
│   1542 │   │                                                                 │
│   1543 │   │   try:                                                          │
│   1544 │   │   │   result = None                                             │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/a │
│ ccelerate/hooks.py:166 in new_forward                                        │
│                                                                              │
│   163 │   │   │   with torch.no_grad():                                      │
│   164 │   │   │   │   output = module._old_forward(*args, **kwargs)          │
│   165 │   │   else:                                                          │
│ ❱ 166 │   │   │   output = module._old_forward(*args, **kwargs)              │
│   167 │   │   return module._hf_hook.post_forward(module, output)            │
│   168 │                                                                      │
│   169 │   # Overriding a GraphModuleImpl forward freezes the forward call an │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/b │
│ itsandbytes/nn/modules.py:468 in forward                                     │
│                                                                              │
│   465 │   │   │   x = x.to(self.compute_dtype)                               │
│   466 │   │                                                                  │
│   467 │   │   bias = None if self.bias is None else self.bias.to(self.comput │
│ ❱ 468 │   │   out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_sta │
│   469 │   │                                                                  │
│   470 │   │   out = out.to(inp_dtype)                                        │
│   471                                                                        │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/b │
│ itsandbytes/autograd/_functions.py:579 in matmul_4bit                        │
│                                                                              │
│   576 │   │   │   │   out += bias                                            │
│   577 │   │   │   return out                                                 │
│   578 │   else:                                                              │
│ ❱ 579 │   │   return MatMul4Bit.apply(A, B, out, bias, quant_state)          │
│   580                                                                        │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/t │
│ orch/autograd/function.py:598 in apply                                       │
│                                                                              │
│   595 │   │   if not torch._C._are_functorch_transforms_active():            │
│   596 │   │   │   # See NOTE: [functorch vjp and autograd interaction]       │
│   597 │   │   │   args = _functorch.utils.unwrap_dead_wrappers(args)         │
│ ❱ 598 │   │   │   return super().apply(*args, **kwargs)  # type: ignore[misc │
│   599 │   │                                                                  │
│   600 │   │   if not is_setup_ctx_defined:                                   │
│   601 │   │   │   raise RuntimeError(                                        │
│                                                                              │
│ /home/students/mai/miniconda3/envs/prometheus/lib/python3.12/site-packages/b │
│ itsandbytes/autograd/_functions.py:509 in forward                            │
│                                                                              │
│   506 │   │                                                                  │
│   507 │   │   # 1. Dequantize                                                │
│   508 │   │   # 2. MatmulnN                                                  │
│ ❱ 509 │   │   output = torch.nn.functional.linear(A, F.dequantize_4bit(B, qu │
│   510 │   │                                                                  │
│   511 │   │   # 3. Save state                                                │
│   512 │   │   ctx.state = quant_state                                        │
╰──────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 172.00 MiB. GPU 
E0628 17:14:30.162000 22490380526784 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 17678) of binary: .../python
...

Important: I’m using:

  • a custom PyTorch training and evaluation loop, not the huggingface Trainer API.
  • bitsandbytes 4-bit quantization.

Things that I’ve tried:

  • Setting model.eval(). This actually seems to be part of the problem: when I replace it with model.train(), memory stats stay constant from step to step and I don’t run into memory issues.
  • Calling gc.collect() and torch.cuda.empty_cache() before and after every eval step, and wrapping the forward pass in with torch.no_grad() (sketched below). No positive effect.
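
Roughly, one eval step with those mitigations looked like this (a simplified sketch, not my exact code; the function name is just illustrative):

import gc
import torch

def eval_step_with_cleanup(model, batch, eval_loss):
    # Illustrative sketch of a single evaluation step with the attempted mitigations.
    gc.collect()
    torch.cuda.empty_cache()
    with torch.no_grad():
        outputs = model.multitask_forward(batch)
        eval_loss += outputs.loss.detach().float()
    gc.collect()
    torch.cuda.empty_cache()
    return eval_loss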

Does anybody have any idea what might be causing the problem and/or how I could go about debugging it?

Since setting model.train() seems to work, would evaluating in training mode be an option if I can’t figure out a way to make this work in eval mode? If so, is there anything I have to keep in mind when doing that?

Please let me know if you need additional information or code. I haven’t put together an actual minimal working example yet, but if you need one, I’ll work on it.
Many thanks in advance, any comments would be much appreciated!

I think if you properly use

with torch.no_grad()

then it should work

@Tai-Mai if you can recreate the issue with a smaller example, it will be quicker to debug

A sample with a pretrained model

import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


# Example input text (illustrative; any tokenized sentence pair works here)
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Convert tokens to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Define sentence A and B indices associated to 1st and 2nd sentences (see paper)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

# If you have a GPU, put everything on cuda
tokens_tensor = tokens_tensor.to('cuda')
segments_tensors = segments_tensors.to('cuda')
model.to('cuda')

# Predict hidden states features for each layer
with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segments_tensors)
    encoded_layers = outputs[0]
    print("Max Memory Allocated {}".format(torch.cuda.max_memory_allocated()))
    (free, total) = torch.cuda.mem_get_info()
    print("USed Memory {}".format(total-free))

with torch.no_grad():
    outputs = model(tokens_tensor, token_type_ids=segments_tensors)
    encoded_layers = outputs[0]
    print("Max Memory Allocated {}".format(torch.cuda.max_memory_allocated()))
    (free, total) = torch.cuda.mem_get_info()
    print("USed Memory {}".format(total-free))

[screenshot: printed memory stats]

Thanks for the reply!
Unfortunately, I’m already using with torch.no_grad() for evaluation but it doesn’t have any positive effect. Training works fine and I don’t use with torch.no_grad() during training.

@Tai-Mai I see.
Can you either

  • send the entire code along with the data, or

  • recreate the issue with some pretrained transformer model?

Unfortunately, it’s a huge and confusing code base and I’m under contract, so I’m not sure I can share the entire repository with you. Maybe it helps if I show you my training and evaluation functions?

def train(
    model,
    train_dataloader,
    eval_dataloader,
    tokenizer,
    optimizer,
    lr_scheduler,
    gradient_accumulation_steps,
    train_config,
    fsdp_config=None,
    local_rank=None,
    rank=None,
    tensorboard=None,
):
    """
    Trains the model on the given dataloader

    Args:
        model: The model to be trained
        train_dataloader: The dataloader containing the training data
        optimizer: The optimizer used for training
        lr_scheduler: The learning rate scheduler
        gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
        num_epochs: The number of epochs to train for
        local_rank: The rank of the current node in a distributed setting
        train_config: The training configuration
        eval_dataloader: The dataloader containing the eval data
        tokenizer: tokenizer used in the eval for decoding the predictions

    Returns: results dictionary containing average training and validation perplexity and loss
    """
    train_prep = []
    train_loss = []
    val_prep = []
    val_loss = []
    epoch_times = []
    checkpoint_times = []
    results = {}
    best_val_loss = float("inf")

    for epoch in range(train_config.num_epochs):
        # Run an initial evaluation before any training (this is where the OOM occurs)
        if train_config.run_validation:
            eval_ppl, eval_epoch_loss = evaluation(
                model, train_config, eval_dataloader, local_rank, tokenizer
            )

        epoch_start_time = time.perf_counter()
        with MemoryTrace() as memtrace:  # track the memory usage
            model.train()
            model.prompting_strategy = train_config.prompting_strategy
            total_loss = 0.0
            for step, batch in enumerate(
                tqdm(train_dataloader, colour="blue", desc=f"Training Epoch {epoch}")
            ):
                # loss = model(**batch).loss
                # print(tokenizer.decode(batch["vanilla"]["input_ids"][0]))
                # output = model(**batch["vanilla"])
                output = model.multitask_forward(batch)

                loss = output.loss
                loss = loss / gradient_accumulation_steps
                total_loss += loss.detach().float()

                loss.backward()
                # TODO: Add argument for gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.8)
                if (step + 1) % gradient_accumulation_steps == 0 or step == len(
                    train_dataloader
                ) - 1:
                    optimizer.step()
                    optimizer.zero_grad()
                    if train_config.scheduler == "cosine":
                        lr_scheduler.step()
        epoch_end_time = time.perf_counter() - epoch_start_time
        epoch_times.append(epoch_end_time)
        # Reducing total_loss across all devices if there's more than one CUDA device
        train_epoch_loss: torch.Tensor = total_loss / len(train_dataloader)
        train_perplexity: float = torch.exp(train_epoch_loss).item()
        train_epoch_loss = train_epoch_loss.item()

        train_prep.append(train_perplexity)
        train_loss.append(train_epoch_loss)

        # Update the learning rate as needed
        # for stepLR only
        if train_config.scheduler == "step":
            lr_scheduler.step()

        if train_config.run_validation:
            eval_ppl, eval_epoch_loss = evaluation(
                model, train_config, eval_dataloader, local_rank, tokenizer
            )
            val_prep.append(eval_ppl)
            val_loss.append(eval_epoch_loss)

            if tensorboard:
                # wandb_logger.log({"epoch":epoch,"eval_loss": total_loss,"learning_rate":lr_scheduler.get_lr()})
                tensorboard.add_scalar("Loss/eval", eval_epoch_loss, epoch)
                tensorboard.add_scalar("PPL/eval", eval_ppl, epoch)
                # tensorboard.add_scalar("Learning rate", optimizer.param_groups[-1]['lr'], epoch)

        checkpoint_start_time = time.perf_counter()
        # if train_config.save_model and eval_epoch_loss < best_val_loss:
        print(
            f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s"
        )
    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if checkpoint_times else 0.0
    avg_train_prep = sum(train_prep) / len(train_prep)
    avg_train_loss = sum(train_loss) / len(train_loss)
    if train_config.run_validation:
        avg_eval_prep = sum(val_prep) / len(val_prep)
        avg_eval_loss = sum(val_loss) / len(val_loss)

    results["avg_train_prep"] = avg_train_prep
    results["avg_train_loss"] = avg_train_loss
    if train_config.run_validation:
        results["avg_eval_prep"] = avg_eval_prep
        results["avg_eval_loss"] = avg_eval_loss
    results["avg_epoch_time"] = avg_epoch_time
    results["avg_checkpoint_time"] = avg_checkpoint_time

    # saving the training params including fsdp setting for reference.
    # if train_config.enable_fsdp and not train_config.use_peft:
    #     save_train_params(train_config, fsdp_config, rank)

    return results


def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
    """
    Evaluates the model on the given dataloader

    Args:
        model: The model to evaluate
        eval_dataloader: The dataloader containing the evaluation data
        local_rank: The rank of the current node in a distributed setting
        tokenizer: The tokenizer used to decode predictions

    Returns: eval_ppl, eval_epoch_loss
    """
    # if train_config.enable_fsdp:
    #     world_size = int(os.environ["WORLD_SIZE"])
    # model.train()
    eval_preds = []
    eval_loss = 0.0  # Initialize evaluation loss
    with torch.no_grad():
        model.eval()
        # with MemoryTrace() as memtrace:
        for step, batch in enumerate(
            tqdm(eval_dataloader, colour="green", desc="evaluating Epoch")
        ):
            with MemoryTrace() as memtrace:
                print(f"STEP {step}")
                outputs = model.multitask_forward(batch)
                loss = outputs.loss
                eval_loss += loss.detach().float()
            for task, pred in outputs.preds.items():
                pred = pred.detach().cpu()
                if task == "score":
                    eval_preds.append(pred)
                else:  # task == "feedback"
                    eval_preds.append(tokenizer.batch_decode(pred, skip_special_tokens=True))

            # gc.collect()
            # torch.cuda.empty_cache()

    # Compute average loss and perplexity
    eval_epoch_loss: torch.Tensor = eval_loss / len(eval_dataloader)
    # if train_config.enable_fsdp:
    #     eval_epoch_loss: torch.Tensor = eval_epoch_loss / world_size
    eval_ppl: torch.Tensor = torch.exp(eval_epoch_loss)

    eval_epoch_loss: float = eval_epoch_loss.item()
    eval_ppl: float = eval_ppl.item()

    # Print evaluation metrics
    print(f" {eval_ppl=} {eval_epoch_loss=}")

    return eval_ppl, eval_epoch_loss

The important part is that the Siamese forward function (see the model file in my original question) is called with the following line:

output = model.multitask_forward(batch)

A batch looks like this:

{
    "score": {
        "input_ids": [...],
        "labels": [...],
        "attention_mask": [...],
    },
    "feedback": {
        "input_ids": [...],
        "labels": [...],
        "attention_mask": [...],
    }
}

The score dictionary contains the data for the classification task, the feedback dictionary contains the data for the causal language modeling task.
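
To make the shapes concrete, a dummy batch in this format could look like this (illustrative sizes and random values, not my real data):

import torch

batch_size, seq_len, num_classes, vocab_size = 2, 16, 5, 32000  # illustrative sizes

dummy_batch = {
    "score": {  # classification task: one class label per sequence
        "input_ids": torch.randint(0, vocab_size, (batch_size, seq_len)),
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "labels": torch.randint(0, num_classes, (batch_size,)),
    },
    "feedback": {  # causal LM task: one target token id per position
        "input_ids": torch.randint(0, vocab_size, (batch_size, seq_len)),
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "labels": torch.randint(0, vocab_size, (batch_size, seq_len)),
    },
}
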
Is this perhaps of any use to you?
I’ll try to see what happens if I replace my SiameseLlama with a normal Llama and report back.

Yeah, that’s not much help. On further looking at your first post:

  • There is definitely a forward call in your process which requires memory. Plain eval should not require a forward call

Plain eval should not require a forward call

I’m not sure I understand, could you elaborate? From my understanding, doing a forward call just means sending data through the model, typically by doing something like model(**batch) (in my case model.multitask_forward(batch)) and I obviously have to send my evaluation data through my model to do evaluation, am I mistaken? If you mean backward call, then I only do the backward call during training by doing loss.backward().

@Tai-Mai
I was able to re-create the error with your model

import utils
#from utils.metrics import WeightedKappa
#from configs.training import TrainConfig

from dataclasses import dataclass, field
import os
import torch
from transformers import (
    LlamaForSequenceClassification,
    LlamaForCausalLM,
    PreTrainedModel,
    BitsAndBytesConfig,
)
from transformers.modeling_outputs import (
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from typing import Callable


#type DataDict = dict[str, torch.Tensor]
DataDict = dict[str, torch.Tensor]

class SiameseLoss(torch.nn.Module):
    def __init__(
        self,
        num_classes: int = 5,
        score_feedback_ratio: tuple[float, float] = (0.5, 0.5),
        scoring_loss: str = "cross_entropy",
    ) -> None:
        """
        `scoring_loss`: {"cross_entropy", "weighted_kappa", "quadratic_weighted_kappa"}
        """
        super().__init__()
        if scoring_loss not in (
            valid_scoring_losses := {
                "cross_entropy",
                "linear_weighted_kappa",
                "quadratic_weighted_kappa",
            }
        ):
            raise AssertionError(f"`scoring_loss` must be in {str(valid_scoring_losses)}")
        self.scoring_loss = scoring_loss
        if self.scoring_loss != "cross_entropy":
            self.kappa_weighting = self.scoring_loss.split("_")[0]
        self.num_classes = num_classes
        # norm the score_feedback_ratio weights between 0 and 1 if they weren't already
        self.score_feedback_ratio: tuple = tuple(
            w / sum(score_feedback_ratio) for w in score_feedback_ratio
        )

    def forward(
        self,
        classifier_output: SequenceClassifierOutputWithPast | None = None,
        lm_output: CausalLMOutputWithPast | None = None,
        classifier_labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        `labels` can be left as None if `scoring_loss == "cross_entropy"`.
        Returns a torch.Tensor with a single value, can be backwarded.
        """
        if not (classifier_output or lm_output):
            raise AssertionError("Missing `classifier_output` and/or `lm_output`.")

        if classifier_output and not classifier_labels:
            raise AssertionError(
                "When passing the classifier output, also have to pass the true labels."
            )

        if not classifier_output:
            classifier_loss = 0
        elif self.scoring_loss == "cross_entropy":
            if classifier_output.loss is None:
                raise ValueError(
                    "`classifier_output` does not have a `.loss` attribute. This probably means that the classification head did not receive the labels during its forward pass."
                )
            classifier_loss = classifier_output.loss
        else:  # weighted kappa loss
            a=1
            #weighted_kappa = WeightedKappa(
            #    weighting=self.kappa_weighting, num_classes=self.num_classes, as_loss=True
            #)
            #classifier_loss = weighted_kappa(output=classifier_output, labels=classifier_labels)

        if lm_output:
            if not lm_output.loss:
                raise ValueError(
                    "`lm_output` does not have a `.loss` attribute. This probably means that the LM head did not receive the labels during its forward pass."
                )
            lm_loss = lm_output.loss
        else:
            lm_loss = 0

        siamese_loss = (
            classifier_loss * self.score_feedback_ratio[0] + lm_loss * self.score_feedback_ratio[1]
        )
        assert siamese_loss != 0

        return siamese_loss

@dataclass
class SiameseOutput:
    classifier_output: SequenceClassifierOutputWithPast | None = None
    lm_output: CausalLMOutputWithPast | None = None
    loss: SiameseLoss | None = None
    preds: DataDict = field(default_factory=dict)

class SiameseLlama:
    def __init__(
        self,
        models: dict[str, PreTrainedModel | None],
        prompting_strategy: str = "vanilla",
        score_feedback_ratio: tuple[float, float] = (0.5, 0.5),
        scoring_loss: str = "cross_entropy",
    ):
        if not models["causal_lm"] and not models["classifier"]:
            raise ValueError(
                "Must pass at least a `causal_lm` or `classifier` in the `models` dict."
            )

        self.score_feedback_ratio: tuple[float, float] = score_feedback_ratio
        self.scoring_loss: str = scoring_loss
        self.prompting_strategy: str = prompting_strategy
        self.tasks: list[str] = self.prompting_strategy.split("_then_")

        self.models = models
        for model in self.models.values():
            if model:
                self.parameters: Callable = model.parameters
                self.resize_token_embeddings: Callable = model.resize_token_embeddings
                # a bool that's True if model is in training mode, False if in eval mode
                self.training: bool = model.training
                break

    def train(self) -> None:
        self.training = True
        for model in self.models.values():
            if model:
                model.train()

    def eval(self) -> None:
        self.training = False
        for model in self.models.values():
            if model:
                model.eval()

    def to(self, device, *args, **kwargs) -> None:
        for model in self.models.values():
            if model:
                model.to(device)

    def multitask_forward(
        self,
        datapoint: dict[str, DataDict],
        **kwargs,
    ) -> SiameseOutput:
        """
        Returns a loss if labels were given. Otherwise, returns a list of huggingface output objects.
        """

        siamese_output = SiameseOutput()

        for task in datapoint.keys():
            task_data: DataDict = datapoint[task]
            for key in task_data:
                print("The current key is {}".format(key))
                task_data[key] = task_data[key].to("cuda")

            if task == "score":
                task_output = self.models["classifier"](**task_data, **kwargs)
                siamese_output.classifier_output = task_output

            else:  # task == "feedback"
                task_output = self.models["causal_lm"](**task_data, **kwargs)
                siamese_output.lm_output = task_output
            siamese_output.preds[task] = torch.argmax(task_output.logits, dim=-1)

        siamese_loss = SiameseLoss(
            score_feedback_ratio=self.score_feedback_ratio, scoring_loss=self.scoring_loss
        )
        siamese_output.loss = siamese_loss(
            classifier_output=siamese_output.classifier_output,
            lm_output=siamese_output.lm_output,
            classifier_labels=datapoint.get("score", {}).get("labels", None),
        )
        return siamese_output

from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
model = AutoModelForCausalLM.from_pretrained("bert-base-cased")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
siameseLlama = SiameseLlama(models = {"causal_lm": model})
siameseLlama.to('cuda')

from datasets import load_dataset
import torch
eli5 = load_dataset("eli5_category", split="train[0:10]")
eli5 = eli5.train_test_split(test_size=0.2)

def preprocess_function(examples):
    return tokenizer([" ".join(x["text"]) for x in examples["answers"]])
    
tokenized_eli5 = eli5.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=eli5["train"].column_names,
)

block_size = 128
def group_texts(examples):
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
lm_dataset = tokenized_eli5.map(group_texts, batched=True, num_proc=4)
import numpy as np
lm_dataset_datadict = {'input_ids':torch.tensor(np.array(lm_dataset['train']['input_ids'])),
                      'attention_mask':torch.tensor(np.array(lm_dataset['train']['attention_mask'])),
                      'labels':torch.tensor(np.array(lm_dataset['train']['labels']))
                    }
print(torch.cuda.mem_get_info())
siameseLlama.eval()
for i in range(4):
    print("Completed for epoch {}".format(i))
    with torch.no_grad():
        siameseLlama.multitask_forward({'train': lm_dataset_datadict})
    print(torch.cuda.mem_get_info())

[screenshot: printed memory stats]

Now, to re-create your issue, I remove the no_grad clause:

print(torch.cuda.mem_get_info())
siameseLlama.eval()
for i in range(4):
    print("Completed for epoch {}".format(i))
    siameseLlama.multitask_forward({'train': lm_dataset_datadict})
    print(torch.cuda.mem_get_info())

I am not sure there is anything more to it

I also tried evaluating the model without your Siamese class, and that also worked well:

for i in range(4):
    print(torch.cuda.mem_get_info())
    with torch.no_grad():
        model.to('cuda')
        model(**lm_dataset_datadict)

Sorry about the late reply; it took me some time to find the solution to the issue, but your reproduction was really helpful!

Turns out the solution was to manually delete the model output with del outputs after every evaluation step. This seems to fix the memory leak. I still don’t understand why this isn’t necessary during training, though.
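
Concretely, the inner loop of the evaluation() function I posted above now looks roughly like this (simplified sketch):

    with torch.no_grad():
        model.eval()
        for step, batch in enumerate(
            tqdm(eval_dataloader, colour="green", desc="evaluating Epoch")
        ):
            outputs = model.multitask_forward(batch)
            eval_loss += outputs.loss.detach().float()
            for task, pred in outputs.preds.items():
                pred = pred.detach().cpu()
                if task == "score":
                    eval_preds.append(pred)
                else:  # task == "feedback"
                    eval_preds.append(tokenizer.batch_decode(pred, skip_special_tokens=True))
            # Drop the reference to the output object (and its logits) so that
            # the memory can actually be freed before the next step.
            del outputs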

Anyway, thanks again!