CUDA OOM after slight modifications

Hi everyone,
At this point I'm desperate and don't really know what happened.
My model is an implementation of the Swin Transformer. I did not modify the model code, and I assume that implementation is not causing this error. Instead, I use another class to call the model and the head, and all of this worked fine. However, I realised that I need to do multiclass detection, so I switched the loss function in my PyTorch Lightning module to BCELoss. The structure did not change at all:

class LitModule3D(pl.LightningModule):
    """Lightning wrapper around a Swin encoder trained with binary cross-entropy.

    ``forward`` returns per-class sigmoid probabilities. ``training_step``
    computes the same BCE objective directly on the raw classifier scores
    (``BCEWithLogitsLoss``). That form is numerically stable and, unlike
    ``BCELoss``, is allowed inside autocast / mixed-precision regions.
    """

    def __init__(self):
        # LightningModule (an nn.Module) must be initialised before any
        # submodule is assigned; otherwise PyTorch raises an AttributeError.
        super().__init__()
        # NOTE(review): `ckpt_path` and `device` are names defined outside this
        # snippet. If `device` is a fixed GPU, every distributed (DDP) process
        # may place its model copy on that same GPU and exhaust its memory.
        # Letting Lightning handle device placement avoids that. Confirm what
        # swin_encoder does with this argument.
        self.classifier = swin_encoder(1, ckpt_path, device)
        self.sigmoid = torch.nn.Sigmoid()
        # Sigmoid + BCE fused into one numerically stable, autocast-safe op.
        self.loss_func = torch.nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.AdamW(self.classifier.parameters(), lr=0.0001, weight_decay=0.0001)

    def forward(self, x):
        """Return sigmoid probabilities for the classifier's output on ``x``."""
        output = self.classifier(x)
        return self.sigmoid(output)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the BCE training loss for one batch.

        Args:
            batch: a ``(inputs, targets)`` pair. Targets are presumably
                multi-hot labels with the same shape as the classifier output
                (TODO confirm against the dataset).
            batch_idx: index of the batch (unused).

        Returns:
            The scalar loss tensor. Lightning only runs backward and the
            optimizer step when this value is returned.
        """
        x, l = batch
        logits = self.classifier(x)
        # BCE losses require floating-point targets matching the input dtype.
        loss = self.loss_func(logits, l.type_as(logits))

        self.log('loss/loss', loss, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """Hand the optimizer built in ``__init__`` to Lightning."""
        return self.optimizer

After these modifications I used the same PL module to train another model, and that worked too, but the 3D version always throws this error, even with distributed data:

    z = self.encoder(x)
  File "/home/usr/anaconda3/envs/swin/lib/python3.7/site-packages/torch/nn/modules/", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/usr/Documents/Code/Swin/", line 64, in forward
    feat = self.model.forward(imgs)
  File "/home/usr/Documents/Code/Swin/", line 566, in forward
    x = layer(x.contiguous())
  File "/home/usr/anaconda3/envs/swin/lib/python3.7/site-packages/torch/nn/modules/", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/usr/Documents/Code/Swin/", line 405, in forward
    x = blk(x, attn_mask)
  File "/home/usr/anaconda3/envs/swin/lib/python3.7/site-packages/torch/nn/modules/", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/usr/Documents/Code/Swin/", line 268, in forward
    x = self.forward_part1(x, mask_matrix)
  File "/home/usr/Documents/Code/Swin/", line 240, in forward_part1
    attn_windows = self.attn(x_windows, mask=attn_mask)  # B*nW, Wd*Wh*Ww, C
  File "/home/usr/anaconda3/envs/swin/lib/python3.7/site-packages/torch/nn/modules/", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/usr/anaconda3/envs/swin/lib/python3.7/site-packages/torch/amp/", line 12, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/usr/Documents/Code/Swin/", line 155, in forward
    attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
RuntimeError: CUDA out of memory. Tried to allocate 76.00 MiB (GPU 0; 10.92 GiB total capacity; 5.76 GiB already allocated; 20.44 MiB free; 5.82 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I really hope someone can help me with this information. This is the code that causes the error:

    def forward(self, x, mask=None):
        """Windowed multi-head self-attention with relative position bias.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, N, N) or None

        Returns:
            Tensor of shape (num_windows*B, N, ...) produced by ``self.proj``
            followed by ``self.proj_drop``.

        Note:
            The attention matrix is (num_windows*B, nH, N, N) with
            N = Wd*Wh*Ww. This is the dominant memory cost, and it grows with
            the square of the number of tokens per window.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B_, nH, N, C

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index[:N, :N].reshape(-1)
        ].reshape(N, N, -1)  # Wd*Wh*Ww, Wd*Wh*Ww, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wd*Wh*Ww, Wd*Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # B_, nH, N, N

        if mask is not None:
            nW = mask.shape[0]
            # Broadcast the per-window mask over the batch and head dimensions.
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        # Normalise exactly once on both paths. Previously the masked path
        # applied softmax twice and the unmasked path never applied it.
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x