I want to compute self-attention on a tensor x, but I get an out-of-memory error. What is the right way to do self-attention over a long sequence, e.g. 400 timesteps?
I am using a K80 GPU, which has 12 GB of memory.
The code below raises an OOM error:
import torch
import torch.nn as nn

bsz = 32      # batch size
seqlen = 400  # sequence length
hdim = 300    # hidden dimension

x = torch.randn(bsz, seqlen, hdim).cuda()
# key[b, i, j, :] = x[b, i, :] and query[b, i, j, :] = x[b, j, :]
attention_query = x.unsqueeze(1).expand(-1, seqlen, -1, hdim)
attention_key = x.unsqueeze(2).expand(-1, -1, seqlen, -1)
# a has shape (bsz, seqlen, seqlen, 3 * hdim) -- this cat materializes it
a = torch.cat([attention_key, attention_key * attention_query, attention_query], dim=3)
similarity_matrix = nn.Linear(3 * hdim, 1).cuda()(a)
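A quick back-of-the-envelope calculation (my arithmetic, assuming float32 at 4 bytes per element) suggests the concatenated tensor alone cannot fit in 12 GB:

```python
# Memory footprint of the concatenated tensor `a` above.
bsz, seqlen, hdim = 32, 400, 300
elements = bsz * seqlen * seqlen * 3 * hdim  # a: (bsz, seqlen, seqlen, 3*hdim)
gb = elements * 4 / 1024 ** 3                # float32 = 4 bytes per element
print(round(gb, 1))                          # about 17.2 GB for `a` alone
```

And that is before counting the attention_key * attention_query intermediate and any gradient buffers.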
I then tried another way, computing the similarity matrix one query row at a time:
import torch
import torch.nn as nn

bsz = 32
seqlen = 400
hdim = 300

x = torch.randn(bsz, seqlen, hdim).cuda()
similarity_matrix = torch.empty(bsz, seqlen, seqlen).cuda()  # result buffer on the GPU
lin = nn.Linear(hdim * 3, 1).cuda()
for i in range(seqlen):
    attention_query = x[:, i, :]                               # (bsz, hdim)
    out = attention_query.unsqueeze(1).expand(-1, seqlen, -1)  # (bsz, seqlen, hdim)
    subsim = lin(torch.cat([x, out, out * x], dim=2)).squeeze(2)
    similarity_matrix[:, i, :] = subsim
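My guess is that the loop saves nothing because autograd keeps every iteration's cat(...) output alive as the saved input for lin's backward pass (a rough estimate, assuming float32; this is my hypothesis, not something I have profiled):

```python
# Estimated memory retained by autograd across the loop, assuming each
# iteration's cat([x, out, out * x]) output is saved for backward.
bsz, seqlen, hdim = 32, 400, 300
per_iter = bsz * seqlen * 3 * hdim * 4    # one (bsz, seqlen, 3*hdim) float32 tensor
total_gb = per_iter * seqlen / 1024 ** 3  # retained over all seqlen iterations
print(round(total_gb, 1))                 # about the same 17 GB as the one-shot version
```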
This still fails with an out-of-memory error:
THCudaCheck FAIL file=/pytorch/aten/src/THC/generic/THCStorage.cu line=58 error=2 : out of memory
Traceback (most recent call last):
File "test.py", line 18, in <module>
subsim = lin( torch.cat([x, out, out*x], dim=2) ).squeeze(2)
RuntimeError: cuda runtime error (2) : out of memory at /pytorch/aten/src/THC/generic/THCStorage.cu:58
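Since nn.Linear(3 * hdim, 1) is just a dot product with a 3*hdim weight vector, I believe the similarity matrix can be computed without ever materializing the (bsz, seqlen, seqlen, 3*hdim) tensor, by splitting the weight into its key, key*query, and query parts. A sketch of that decomposition (assuming the cat order [key, key*query, query] used above; shown on the CPU, add .cuda() calls for the GPU):

```python
import torch
import torch.nn as nn

bsz, seqlen, hdim = 32, 400, 300
x = torch.randn(bsz, seqlen, hdim)

lin = nn.Linear(3 * hdim, 1)

# Split the (1, 3*hdim) weight into the three hdim-sized pieces that multiply
# key, key*query and query respectively, matching the cat order above.
w = lin.weight.view(3, hdim)
w_k, w_kq, w_q = w[0], w[1], w[2]

# sim[b, i, j] = w_k . x[b,i] + w_kq . (x[b,i] * x[b,j]) + w_q . x[b,j] + bias
sim = (
    (x @ w_k).unsqueeze(2)            # key term, broadcasts over j
    + (x * w_kq) @ x.transpose(1, 2)  # trilinear key*query term
    + (x @ w_q).unsqueeze(1)          # query term, broadcasts over i
    + lin.bias                        # scalar bias broadcasts everywhere
)
print(sim.shape)  # torch.Size([32, 400, 400])
```

The largest tensor here is the (bsz, seqlen, seqlen) result itself, about 20 MB in float32, so it should fit comfortably in 12 GB even with gradients.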