Hi, I am using `LSTMCell` as the decoder for a task similar to machine translation. However, I ran into the following GPU out-of-memory error:
RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 23.64 GiB total capacity; 18.54 GiB already allocated; 8.06 MiB free; 22.44 GiB reserved in total by PyTorch)
I checked GPU memory usage while the decoder was running by adding print(f"allcoated {torch.cuda.memory_allocated(0)/(1024**3)} GB")
inside my custom attention module, and saw that memory usage keeps climbing.
The detailed error message points to the Attention module:
<ipython-input-15-ddbbf84c3dca> in forward(self, encoder_out, decoder_hidden)
29 print("---")
30 # alpha: calculated base level attention
---> 31 alpha = self.softmax(att) # (batch, sequence_len); row sum = 1
32 attention_weighted_encoding = encoder_out * alpha.unsqueeze(2)
33 attention_weighted_encoding = torch.sum(attention_weighted_encoding, dim=1) # (batch, encoder_dim)
~/anaconda3/envs/torch_dev/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
--> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),
~/anaconda3/envs/torch_dev/lib/python3.8/site-packages/torch/nn/modules/activation.py in forward(self, input)
1196
1197 def forward(self, input: Tensor) -> Tensor:
-> 1198 return F.softmax(input, self.dim, _stacklevel=5)
1199
1200 def extra_repr(self) -> str:
~/anaconda3/envs/torch_dev/lib/python3.8/site-packages/torch/nn/functional.py in softmax(input, dim, _stacklevel, dtype)
1510 dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
1511 if dtype is None:
-> 1512 ret = input.softmax(dim)
1513 else:
1514 ret = input.softmax(dim, dtype=dtype)
Here is the code of the attention module:
class Attention(nn.Module):
    """
    Additive attention over a sequence of encoder outputs.

    Scores each encoder position against the current decoder hidden state,
    normalizes the scores with a softmax over the sequence dimension, and
    returns the attention-weighted sum of the encoder features together with
    the attention weights.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of each encoder output vector
        :param decoder_dim: size of the decoder hidden state
        :param attention_dim: size of the shared projection space in which
            encoder and decoder features are combined
        """
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden, encoder_att_out=None):
        """
        Compute one step of attention.

        :param encoder_out: encoder features, (batch, sequence_len, encoder_dim)
        :param decoder_hidden: decoder hidden state, (batch, decoder_dim)
        :param encoder_att_out: optional precomputed ``self.encoder_att(encoder_out)``,
            shape (batch, sequence_len, attention_dim). ``encoder_out`` usually
            does not change across decoding steps, so a decoder loop can compute
            this projection once before the loop and pass
            ``encoder_att_out[:batch_size_t]`` at each step. That avoids
            re-running the linear layer and allocating a fresh
            (batch, sequence_len, attention_dim) tensor on every time step.
            When omitted, the projection is computed here, exactly as before.
        :return: tuple ``(attention_weighted_encoding, alpha)`` where
            ``attention_weighted_encoding`` is (batch, encoder_dim) and
            ``alpha`` is (batch, sequence_len) with rows summing to 1.
        """
        if encoder_att_out is None:
            # (batch, sequence_len, encoder_dim) -> (batch, sequence_len, attention_dim)
            encoder_att_out = self.encoder_att(encoder_out)
        # (batch, decoder_dim) -> (batch, attention_dim)
        decoder_att_out = self.decoder_att(decoder_hidden)

        # Broadcast the decoder projection over every sequence position,
        # then reduce each position to a scalar score: (batch, sequence_len).
        att = self.full_att(
            self.relu(encoder_att_out + decoder_att_out.unsqueeze(1))
        ).squeeze(2)

        # Attention weights over the sequence; each row sums to 1.
        alpha = self.softmax(att)  # (batch, sequence_len)

        # Weighted sum of encoder features as a batched matrix product:
        # (batch, 1, sequence_len) @ (batch, sequence_len, encoder_dim)
        #   -> (batch, 1, encoder_dim) -> (batch, encoder_dim).
        # Unlike `(encoder_out * alpha.unsqueeze(2)).sum(1)`, this never
        # materializes a (batch, sequence_len, encoder_dim) intermediate.
        attention_weighted_encoding = torch.bmm(
            alpha.unsqueeze(1), encoder_out
        ).squeeze(1)
        return attention_weighted_encoding, alpha
And here is how the decoder works in my code:
# predict sequence
# NOTE(review): memory climbing inside this loop is most likely expected, not a
# leak. While autograd is enabled, the tensors each step's ops keep for
# backward (attention, decode_step, fc) stay alive until backward() runs, so
# usage grows with the number of steps, max(decode_lengths). For evaluation,
# running this loop under torch.no_grad() avoids that. For training, a smaller
# batch size or attention_dim lowers the peak.
# NOTE(review): self.attention is given encoder_out[:batch_size_t] on every
# step, and encoder_out is never reassigned inside this loop. The attention
# module's encoder projection, (batch, sequence_len, attention_dim), is
# therefore recomputed once per step, and precomputing it before the loop
# would avoid that.
# NOTE(review): the [:batch_size_t] slicing of encoder_out/h/c/embeddings
# assumes the batch is sorted by decode_lengths in descending order, so the
# still-active rows come first. Confirm against the caller.
for t in range(max(decode_lengths)):
# number of sequences whose decode length is still greater than t
batch_size_t = sum([l > t for l in decode_lengths])
# attention_weighted_encoding: (batch_size_t, encoder_dim)
# alpha: (batch_size_t, num_of_pixels)
attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim)
attention_weighted_encoding = gate * attention_weighted_encoding
# new hidden and new cell state
# (h and c are rebound to the states of the first batch_size_t rows only)
h, c = self.decode_step(
torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
(h[:batch_size_t], c[:batch_size_t])
)
preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size)
# write this step's outputs into the preallocated result tensors
predictions[:batch_size_t, t, :] = preds
alphas[:batch_size_t, t, :] = alpha
I appreciate any input that could help me solve this issue. Thanks in advance!