CUDA out of memory while calculating DINO model loss / using PYTORCH_CUDA_ALLOC_CONF

Hi, I’m trying to train a DINO model (vit_base) on my own dataset. The first epoch completes fine, but at the first step of the second epoch I get this error:
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.65 GiB total capacity; 22.04 GiB already allocated; 3.56 MiB free; 22.42 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I did some research on similar situations and tried to figure out what value I should set in PYTORCH_CUDA_ALLOC_CONF, but I didn’t find a solution.
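
As far as I understand, the option is passed through an environment variable that has to be set before the first CUDA allocation, something like the sketch below (128 is only a placeholder value, not a recommendation from the docs):

import os

# has to be set before the first CUDA tensor is allocated;
# 128 is only a placeholder value, not a recommendation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch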


Epoch: [0/300]  [633/634]  eta: 0:00:01  loss: 10.988656 (11.053826)  lr: 0.000012 (0.000006)  wd: 0.040010 (0.040003)  time: 1.269988  data: 0.001112  max mem: 22511
Epoch: [0/300] Total time: 0:12:16 (1.161057 s / it)
Averaged stats: loss: 10.988656 (11.053826)  lr: 0.000012 (0.000006)  wd: 0.040010 (0.040003)
Garbage collected 896
Epoch: [1/300]  [  0/634]  eta: 2:03:45  loss: 10.966450 (10.966450)  lr: 0.000013 (0.000013)  wd: 0.040010 (0.040010)  time: 11.712133  data: 10.748226  max mem: 22511
submitit ERROR (2022-11-23 09:26:59,116) - Submitted job triggered an exception

As you can see, the first epoch finishes without problems, but at the first step of the second epoch I get the error.
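
To see whether allocated memory really grows between epochs or it is only the reserved pool getting fragmented, I’m thinking of logging both at each epoch boundary, something like this rough sketch (epoch is whatever counter the training loop already has):

import torch

def log_cuda_memory(epoch):
    # memory held by live tensors vs. memory held by PyTorch's caching allocator
    allocated = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"epoch {epoch}: allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB")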

Full error traceback:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.8/site-packages/submitit/core/_submit.py", line 11, in <module>
    submitit_main()
  File "/opt/conda/lib/python3.8/site-packages/submitit/core/submission.py", line 72, in submitit_main
    process_job(args.folder)
  File "/opt/conda/lib/python3.8/site-packages/submitit/core/submission.py", line 65, in process_job
    raise error
  File "/opt/conda/lib/python3.8/site-packages/submitit/core/submission.py", line 54, in process_job
    result = delayed.result()
  File "/opt/conda/lib/python3.8/site-packages/submitit/core/utils.py", line 133, in result
    self._result = self.function(*self.args, **self.kwargs)
  File "run_with_submitit.py", line 86, in __call__
    main_dino.train_dino(self.args)
  File "/WorkFolder/dino/main_dino.py", line 274, in train_dino
    train_stats = train_one_epoch(student, teacher, teacher_without_ddp, dino_loss,
  File "/WorkFolder/dino/main_dino.py", line 321, in train_one_epoch
    loss = dino_loss(student_output, teacher_output, epoch)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/WorkFolder/dino/main_dino.py", line 401, in forward
    loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.65 GiB total capacity; 22.04 GiB already allocated; 3.56 MiB free; 22.42 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

As we can see, the error is raised by this line:
loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
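
The 16 MiB it fails to allocate would match a single new intermediate tensor from this line, if I assume the default DINO out_dim of 65536 and a per-crop chunk of 64 samples in float32 (just my rough check, both numbers are assumptions about my config):

# rough size check for one intermediate tensor created by -q * F.log_softmax(...)
# 64 and 65536 are assumptions (per-GPU batch size and the default DINO out_dim)
chunk_size, out_dim, bytes_per_float32 = 64, 65536, 4
print(chunk_size * out_dim * bytes_per_float32 / 2**20)  # -> 16.0 (MiB)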

Here is the loss class from main_dino.py:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DINOLoss(nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training instable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss
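
For context, the shapes this loss sees look roughly like the toy sketch below (made-up sizes, assuming the standard DINO setup where the 2 global + 8 local crops are concatenated along the batch dimension):

import torch
import torch.nn.functional as F

batch, out_dim, ncrops = 4, 8, 10                      # toy sizes, not my real config
student_output = torch.randn(ncrops * batch, out_dim)  # student forward on all crops
teacher_output = torch.randn(2 * batch, out_dim)       # teacher forward on the 2 global crops

student_out = (student_output / 0.1).chunk(ncrops)     # 10 chunks of shape (batch, out_dim)
teacher_out = F.softmax(teacher_output / 0.04, dim=-1).detach().chunk(2)
print(len(student_out), student_out[0].shape, len(teacher_out), teacher_out[0].shape)
# -> 10 torch.Size([4, 8]) 2 torch.Size([4, 8])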

I’m using:
CUDA Version: 11.7
Python 3.8.10
torch 1.13.0+cu116
torchvision 0.14.0+cu116

I would really appreciate any ideas!