Here is the code segment:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAtt(nn.Module):
    def __init__(self, input_size):
        super(SelfAtt, self).__init__()
        self.input_size = input_size
        self.linear = nn.Linear(input_size, 1)
        self.linear_u = nn.Linear(input_size, input_size)
        self.linear_w = nn.Linear(input_size, input_size)

    def forward(self, x, x_mask):
        """
        x:      batchsize * len1 * hidden
        x_mask: batchsize * len1 (True/1 at padding positions)
        return: batchsize * len1 * hidden
        """
        len1 = x.size(1)
        ctemp = []
        for i in range(len1):
            x1 = x[:, i, :]                                   # current query position, [b, hidden]
            input1 = self.linear_u(x.view(-1, self.input_size))
            # unsqueeze so the repeated query stays aligned with its own batch rows
            input2 = self.linear_w(x1.unsqueeze(1).repeat(1, len1, 1).view(-1, self.input_size))
            # input1 = x.view(-1, self.input_size)
            # input2 = x1.repeat(1, len1, 1).view(-1, self.input_size)
            hidden = F.relu(input1 + input2)                  # [b * len1, hidden]
            score = self.linear(hidden).view(x.size(0), x.size(1))
            score = score.masked_fill(x_mask, -float('inf'))  # mask out padding positions
            alpha = F.softmax(score, dim=1)                   # attention weights over all positions
            c = alpha.unsqueeze(1).bmm(x).squeeze(1)          # [b, hidden]
            ctemp.append(c)
        return torch.stack(ctemp).transpose(0, 1)             # [b, len1, hidden]
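For reference, this is roughly how I call it (the sizes and variable names below are just placeholders, not my real configuration):

att = SelfAtt(input_size=128).cuda()
x = torch.randn(32, 50, 128).cuda()                     # batchsize * len1 * hidden
x_mask = torch.zeros(32, 50, dtype=torch.bool).cuda()   # True marks padding tokens
out = att(x, x_mask)                                    # batchsize * len1 * hidden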
When I use this module in my network, it rapidly runs out of CUDA memory.
How can I solve this problem?
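My guess is that the per-position Python loop is the issue: it recomputes self.linear_u(x) at every step and keeps the autograd graph of every iteration alive until backward. Would a fully vectorized forward along these lines be the right direction? This is just my own untested sketch of what I think is the same additive attention; it still materializes a batchsize * len1 * len1 * hidden tensor, so I'm not sure it actually lowers peak memory.

    def forward(self, x, x_mask):
        # untested sketch: compute all query/key pairs in one shot instead of looping
        u = self.linear_u(x)                              # [b, len1, hidden], one per key position
        w = self.linear_w(x)                              # [b, len1, hidden], one per query position
        hidden = F.relu(u.unsqueeze(1) + w.unsqueeze(2))  # [b, len1(query), len1(key), hidden]
        score = self.linear(hidden).squeeze(-1)           # [b, len1(query), len1(key)]
        score = score.masked_fill(x_mask.unsqueeze(1), -float('inf'))  # mask padded keys
        alpha = F.softmax(score, dim=2)                   # normalize over key positions
        return alpha.bmm(x)                               # [b, len1, hidden]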