CUDA out of memory when using vision transformer

I am getting a CUDA out-of-memory error when using a vision transformer. I reduced the batch size from 8 to 1 and still get the same error, raised at this line:

attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale

RuntimeError: CUDA out of memory. Tried to allocate 1.15 GiB (GPU 0; 10.76 GiB total capacity; 8.33 GiB already allocated; 1.13 GiB free; 8.49 GiB reserved in total by PyTorch)
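
For context on why a smaller batch size may not be enough: the tensor produced by that matmul has shape (batch, heads, n, n), where n is the number of tokens, so its size grows linearly with batch size but quadratically with sequence length. Below is a rough back-of-the-envelope sketch. The function name and the sequence lengths are hypothetical (the actual token count is not shown in this post), and float32 is assumed.

```python
def attn_scores_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Bytes needed for one (batch, heads, seq_len, seq_len) score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

GIB = 1024 ** 3

# Hypothetical token counts; heads=8 matches the default in SelfAttention.
for n in (256, 1024, 4096):
    size = attn_scores_bytes(batch=1, heads=8, seq_len=n)
    print(f"n={n:5d}: {size / GIB:.3f} GiB")
```

During training, at least one tensor of this size is typically kept per attention layer for the backward pass, on top of temporaries, so the real footprint is likely higher than this single-tensor estimate.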

This is the self-attention class:

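import torch
import torch.nn as nn
import torch.nn.functional as F

# LinearGeneral is defined elsewhere in the project (not shown here).
class SelfAttention(nn.Module):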
    def __init__(self, in_dim, heads=8, dropout_rate=0.1):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.head_dim = in_dim // heads
        self.scale = self.head_dim ** 0.5
        
        self.query = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.key = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.value = LinearGeneral((in_dim,), (self.heads, self.head_dim))
        self.out = LinearGeneral((self.heads, self.head_dim), (in_dim,))

        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            self.dropout = None

    def forward(self, x):
        b, n, _ = x.shape

        q = self.query(x, dims=([2], [0]))
        k = self.key(x, dims=([2], [0]))
        v = self.value(x, dims=([2], [0]))

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        out = torch.matmul(attn_weights, v)
        out = out.permute(0, 2, 1, 3)

        out = self.out(out, dims=([2, 3], [0, 1]))

        return out
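
One possible way to avoid building the full n x n attention matrix by hand is PyTorch's built-in fused attention. The sketch below assumes PyTorch 2.0 or later (the traceback in this post suggests an older install, so an upgrade would be needed), and the function names and tensor shapes are made up for illustration. Whether a memory-saving kernel is actually selected depends on the GPU, dtype, and PyTorch build, so treat it as something to try rather than a guaranteed fix.

```python
import torch
import torch.nn.functional as F

def attention_reference(q, k, v):
    """Same math as the forward() above: explicit (n x n) score matrix."""
    scale = q.shape[-1] ** 0.5
    weights = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / scale, dim=-1)
    return torch.matmul(weights, v)

def attention_fused(q, k, v):
    """Let PyTorch pick an attention kernel (assumes PyTorch >= 2.0)."""
    # The default scaling is 1/sqrt(head_dim), which matches dividing
    # by self.scale = head_dim ** 0.5 in the class above.
    return F.scaled_dot_product_attention(q, k, v)

torch.manual_seed(0)
# Hypothetical shapes: (batch, heads, tokens, head_dim)
q, k, v = (torch.randn(1, 8, 64, 32) for _ in range(3))

same = torch.allclose(
    attention_reference(q, k, v), attention_fused(q, k, v), atol=1e-4
)
print(same)
```

In the class above, this would replace the two matmul calls and the softmax between them; the permute calls before and after stay the same.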

Could you suggest how to handle the OOM problem when using vision transformers?

I initially see this in my output:

[6/176] train loss: 0.890; agg acc: 0.667
[[3. 0. 0.]
 [2. 1. 0.]
 [0. 0. 0.]]

but the computation afterward ends with an OOM error.

Also, I am running ViT with only 1 layer:

if __name__ == '__main__':
    #model = VisionTransformer(num_layers=2)
    model = VisionTransformer(num_layers=1)
    x = torch.randn((2, 3, 256, 256))
    out = model(x)

Reducing the number of layers still yields an OOM, but with a slightly different error:

torch.Size([1260, 512])
pred is:  tensor([1], device='cuda:0')
Traceback (most recent call last):
  File "main.py", line 125, in <module>
    loss.backward()
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/_tensor.py", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/SeaExp/mona/venv/dpcc/lib/python3.8/site-packages/torch/autograd/__init__.py", line 147, in backward
    Variable._execution_engine.run_backward(
RuntimeError: CUDA out of memory. Tried to allocate 1.16 GiB (GPU 0; 10.76 GiB total capacity; 7.33 GiB already allocated; 1.08 GiB free; 7.40 GiB reserved in total by PyTorch)
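
Since this second error is raised inside loss.backward(), activation (gradient) checkpointing may be worth trying: it recomputes a block's intermediate tensors during the backward pass instead of storing them, trading extra compute for lower memory. This is a minimal sketch, not the model from this post. `CheckpointedBlock` and the small stand-in block are hypothetical, and `use_reentrant=False` assumes PyTorch 1.11 or later.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    """Wraps a module so its activations are recomputed during backward."""

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # Intermediate activations inside self.block are not stored;
        # they are recomputed when gradients are needed.
        return checkpoint(self.block, x, use_reentrant=False)

# Hypothetical stand-in for one transformer encoder block.
block = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
wrapped = CheckpointedBlock(block)

x = torch.randn(4, 10, 16, requires_grad=True)
loss = wrapped(x).sum()
loss.backward()
print(x.grad.shape)
```

Wrapping each encoder block (or just the attention module) this way should leave the computed gradients unchanged while making each training step slower.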

Hello! Were you able to resolve this?