GPU out of memory for the decoder

Hi, I am using an LSTMCell as the decoder for a task similar to machine translation. However, I ran into the following GPU out-of-memory problem:

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 23.64 GiB total capacity; 18.54 GiB already allocated; 8.06 MiB free; 22.44 GiB reserved in total by PyTorch)

I checked the GPU memory usage during the decoder's forward pass with print(f"allocated {torch.cuda.memory_allocated(0)/(1024**3)} GB") inside my customized attention module, and saw that the memory usage keeps climbing.
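For context, the check I am running is equivalent to this small helper (the helper is only wrapped up like this for readability in this post; in the module itself it is a plain print, as shown further down):

import torch

def log_gpu_memory(tag: str, device: int = 0) -> None:
    # How much memory tensors currently occupy vs. how much the caching
    # allocator has reserved from the driver.
    allocated = torch.cuda.memory_allocated(device) / (1024 ** 3)
    reserved = torch.cuda.memory_reserved(device) / (1024 ** 3)
    print(f"[{tag}] allocated {allocated:.3f} GB, reserved {reserved:.3f} GB")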

The detailed error message points to the Attention module:

<ipython-input-15-ddbbf84c3dca> in forward(self, encoder_out, decoder_hidden)
     29         print("---")
     30         # alpha: calculated base level attention
---> 31         alpha = self.softmax(att) # (batch, sequence_len); row sum = 1
     32         attention_weighted_encoding = encoder_out * alpha.unsqueeze(2)
     33         attention_weighted_encoding = torch.sum(attention_weighted_encoding, dim=1) # (batch, encoder_dim)

~/anaconda3/envs/torch_dev/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

~/anaconda3/envs/torch_dev/lib/python3.8/site-packages/torch/nn/modules/activation.py in forward(self, input)
   1196 
   1197     def forward(self, input: Tensor) -> Tensor:
-> 1198         return F.softmax(input, self.dim, _stacklevel=5)
   1199 
   1200     def extra_repr(self) -> str:

~/anaconda3/envs/torch_dev/lib/python3.8/site-packages/torch/nn/functional.py in softmax(input, dim, _stacklevel, dtype)
   1510         dim = _get_softmax_dim('softmax', input.dim(), _stacklevel)
   1511     if dtype is None:
-> 1512         ret = input.softmax(dim)
   1513     else:
   1514         ret = input.softmax(dim, dtype=dtype)

Here is the code for the attention module:

class Attention(nn.Module):
    """
Attention network for calculating attention values
    """
    
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
:param encoder_dim: feature size of the encoder output
:param decoder_dim: size of the decoder's hidden state
:param attention_dim: size of the attention network's hidden layer
        """
        super(Attention, self).__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, encoder_out, decoder_hidden):
        att1 = self.encoder_att(encoder_out) # (batch, sequence_len, encoder_dim) -> (batch, sequence_len, attention_dim)
        att2 = self.decoder_att(decoder_hidden) # (batch, decoder_dim) -> (batch, attention_dim)
        print(att1.shape, att2.shape)
        print(f"allcoated {torch.cuda.memory_allocated(0)/(1024**3)} GB")
        # base level attention
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch, sequence_len)
        print(f"allcoated {torch.cuda.memory_allocated(0)/(1024**3)} GB")
        print("---")
        # alpha: calculated base level attention
        alpha = self.softmax(att) # (batch, sequence_len); row sum = 1
        attention_weighted_encoding = encoder_out * alpha.unsqueeze(2)
        attention_weighted_encoding = torch.sum(attention_weighted_encoding, dim=1) # (batch, encoder_dim)
        return attention_weighted_encoding, alpha
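
For reference, this is the kind of standalone shape check I ran on the module above (the dimensions here are made up for illustration, not my real hyperparameters):

import torch

# Hypothetical dimensions, only to confirm the shapes flow through as commented
batch, seq_len = 4, 10
encoder_dim, decoder_dim, attention_dim = 512, 512, 256

attention = Attention(encoder_dim, decoder_dim, attention_dim).cuda()
encoder_out = torch.randn(batch, seq_len, encoder_dim, device="cuda")
decoder_hidden = torch.randn(batch, decoder_dim, device="cuda")

weighted, alpha = attention(encoder_out, decoder_hidden)
print(weighted.shape)  # torch.Size([4, 512]) -> (batch, encoder_dim)
print(alpha.shape)     # torch.Size([4, 10])  -> (batch, sequence_len)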

And here is how the decoder works in my code:

        # predict sequence
        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            # attention_weighted_encoding: (batch_size_t, encoder_dim)
            # alpha: (batch_size_t, num_of_pixels)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            # new hidden and new cell state
            h, c = self.decode_step(
                        torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                        (h[:batch_size_t], c[:batch_size_t])
            )
            preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
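
The climbing memory I mentioned above shows up with per-step logging, the same pattern as in this stripped-down sketch (the dimensions and the bare LSTMCell here are placeholders for my real decoder, which also runs the attention module and a linear head at every step):

import torch
import torch.nn as nn

# Placeholder sizes, just to demonstrate the per-step memory logging
batch, steps, input_dim, hidden_dim = 32, 50, 512, 512
cell = nn.LSTMCell(input_dim, hidden_dim).cuda()
x = torch.randn(batch, steps, input_dim, device="cuda")
h = torch.zeros(batch, hidden_dim, device="cuda")
c = torch.zeros(batch, hidden_dim, device="cuda")

for t in range(steps):
    h, c = cell(x[:, t, :], (h, c))
    print(f"step {t}: allocated {torch.cuda.memory_allocated(0) / (1024 ** 3):.4f} GB")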

I appreciate any input that could help me solve this issue. Thanks in advance!