Out of memory when computing self-attention on a long sequence


I want to compute self-attention over the tensor x, but I get an out-of-memory (OOM) error. What is the right way to do self-attention on a long sequence, e.g. one of length 400?
I am using a K80 GPU, which has 12 GB of memory.
The code below fails with an OOM error:

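import torch
import torch.nn as nn

bsz = 32      # batch size
seqlen = 400  # sequence length
hdim = 300    # hidden dimension

x = torch.randn(bsz, seqlen, hdim).cuda()
# attention_query[b, i, j] = x[b, j] and attention_key[b, i, j] = x[b, i];
# expand() only creates views, so these two lines allocate no new memory
attention_query = x.unsqueeze(1).expand(-1, seqlen, -1, -1)
attention_key = x.unsqueeze(2).expand(-1, -1, seqlen, -1)
# the elementwise product and torch.cat both materialize full (bsz, seqlen, seqlen, ...) tensors
a = torch.cat([attention_key, attention_key * attention_query, attention_query], dim=3)
# the linear layer must live on the GPU as well; squeeze to (bsz, seqlen, seqlen)
similarity_matrix = nn.Linear(hdim * 3, 1).cuda()(a).squeeze(3)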
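A rough estimate of why this version cannot fit, assuming float32 (4-byte) elements: the concatenated input to the linear layer alone is larger than the 12 GB available on one K80 GPU, before the elementwise-product tensor is even counted.

```latex
\begin{aligned}
\text{cat output: } & 32 \times 400 \times 400 \times 900 \times 4\,\text{B} = 18{,}432{,}000{,}000\,\text{B} \approx 18.4\,\text{GB} \\
\text{key} \odot \text{query: } & 32 \times 400 \times 400 \times 300 \times 4\,\text{B} = 6{,}144{,}000{,}000\,\text{B} \approx 6.1\,\text{GB}
\end{aligned}
```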

I tried another approach, computing one row of the similarity matrix at a time:

import torch
import torch.nn as nn

bsz = 32
seqlen = 400
hdim = 300

x = torch.randn(bsz, seqlen, hdim).cuda()
# keep the output on the same device as x
similarity_matrix = torch.zeros(bsz, seqlen, seqlen, device=x.device)
lin = nn.Linear(hdim * 3, 1).cuda()
for i in range(seqlen):
    # query vector x_i, broadcast against every key position j
    attention_query = x[:, i, :]
    out = attention_query.unsqueeze(1).expand(-1, seqlen, -1)
    # (bsz, seqlen, 3 * hdim) input per iteration; autograd keeps it alive for backward
    subsim = lin(torch.cat([x, out, out * x], dim=2)).squeeze(2)
    similarity_matrix[:, i, :] = subsim

This still fails:

THCudaCheck FAIL file=/pytorch/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
  File "test.py", line 18, in <module>
    subsim = lin( torch.cat([x, out,  out*x], dim=2) ).squeeze(2)
RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58
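Most likely, the loop does not help because the parameters of lin require gradients: autograd saves each iteration's concatenated input (needed later for the weight gradient) until backward() runs. Summed over all 400 iterations, that is the same volume as the single concatenated tensor in the first version, again assuming float32:

```latex
\underbrace{32 \times 400 \times 900 \times 4\,\text{B}}_{\approx 46.1\,\text{MB per iteration}} \times 400\ \text{iterations} \approx 18.4\,\text{GB}
```

If only the scores are needed (no training), wrapping the loop in `with torch.no_grad():` lets each iteration's temporaries be freed, so peak memory stays near one iteration's worth. For training, the computation itself has to be restructured.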

(Robin Wang) #2

I have run into the same problem. Have you found a solution?
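One way to avoid building the pairwise tensor at all, sketched under the assumption that the goal is exactly the first version's score lin(cat([x_i, x_i * x_j, x_j])): split the linear layer's weight into three hdim-sized blocks w1, w2, w3 in the cat order. Then s_ij = w1·x_i + (x_i ⊙ w2)·x_j + w3·x_j + b, so the middle term becomes a single batched matrix multiply and the outer two terms are broadcast vectors. Peak memory is then on the order of the (bsz, seqlen, seqlen) output (about 20 MB for the sizes above) instead of tens of gigabytes. The helper name `trilinear_similarity` is hypothetical, not a PyTorch API.

```python
import torch
import torch.nn as nn


def trilinear_similarity(x: torch.Tensor, lin: nn.Linear) -> torch.Tensor:
    """Return lin(cat([x_i, x_i * x_j, x_j])) for every pair (i, j) as a
    (bsz, seqlen, seqlen) tensor, without building the 4-D concatenation.

    Assumes lin = nn.Linear(3 * hdim, 1) and the cat order of the first
    snippet in the question: [key, key * query, query], key = x_i, query = x_j.
    """
    hdim = x.size(-1)
    # Split the (1, 3*hdim) weight into the blocks that multiply each cat part.
    w1, w2, w3 = lin.weight.view(3, hdim)
    # w1 . x_i -> (bsz, seqlen, 1), broadcast over j
    term_key = (x @ w1).unsqueeze(2)
    # w3 . x_j -> (bsz, 1, seqlen), broadcast over i
    term_query = (x @ w3).unsqueeze(1)
    # w2 . (x_i * x_j) == (x_i * w2) . x_j -> one batched matmul, (bsz, seqlen, seqlen)
    term_prod = torch.bmm(x * w2, x.transpose(1, 2))
    return term_key + term_prod + term_query + lin.bias


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bsz, seqlen, hdim = 32, 400, 300
    x = torch.randn(bsz, seqlen, hdim, device=device)
    lin = nn.Linear(3 * hdim, 1).to(device)
    sim = trilinear_similarity(x, lin)
    print(sim.shape)  # torch.Size([32, 400, 400])
```

Only standard differentiable tensor ops are used, so gradients still flow into lin.weight and lin.bias during training. If a model uses a different cat order (the loop version uses [x_j, x_i, x_i * x_j]), the three weight blocks just need to be reassigned to the matching terms.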