Attention matrix in Python with PyTorch

I want to implement Q&A systems with attention mechanism. I have two inputs; context and query which shapes are (batch_size, context_seq_len, embd_size) and (batch_size, query_seq_len, embd_size).
I am following the below paper.
Machine Comprehension Using Match-LSTM and Answer Pointer. https://arxiv.org/abs/1608.07905

Then, I want to obtain a attention matrix which shape is (batch_size, context_seq_len, query_seq_len, embd_size). In the thesis, they calculate values for each row (it means each context word, G_i, alpha_i in the paper).

My code is below and it is running. But I am not sure my way is good or not. For example, I use for loop for generating sequence data (for i in range(T):). And to obtain each row, I use in-place operator like G[:,i,:,:], embd_context[:,i,:].clone() is a good manner in pytorch? If not, where should I change the code?

And if you notice other points, let me know. I am a new in this field and pytorch. Sorry for my ambiguous question.

class MatchLSTM(nn.Module):
    def __init__(self, args):
        super(MatchLSTM, self).__init__()
        self.embd_size = args.embd_size
        d = self.embd_size
        self.answer_token_len = args.answer_token_len
        
        self.embd = WordEmbedding(args)
        self.ctx_rnn   = nn.GRU(d, d, dropout = 0.2)
        self.query_rnn = nn.GRU(d, d, dropout = 0.2)
        
        self.ptr_net = PointerNetwork(d, d, self.answer_token_len) # TBD

        self.w  = nn.Parameter(torch.rand(1, d, 1).type(torch.FloatTensor), requires_grad=True) # (1, 1, d)
        self.Wq = nn.Parameter(torch.rand(1, d, d).type(torch.FloatTensor), requires_grad=True) # (1, d, d)
        self.Wp = nn.Parameter(torch.rand(1, d, d).type(torch.FloatTensor), requires_grad=True) # (1, d, d)
        self.Wr = nn.Parameter(torch.rand(1, d, d).type(torch.FloatTensor), requires_grad=True) # (1, d, d)

        self.match_lstm_cell = nn.LSTMCell(2*d, d)

    def forward(self, context, query):
        # params
        d = self.embd_size
        bs = context.size(0) # batch size
        T = context.size(1)  # context length 
        J = query.size(1)    # query length

        # LSTM Preprocessing Layer
        shape = (bs, T, J, d)
        embd_context     = self.embd(context)         # (N, T, d)
        embd_context, _h = self.ctx_rnn(embd_context) # (N, T, d)
        embd_context_ex  = embd_context.unsqueeze(2).expand(shape).contiguous() # (N, T, J, d)
        embd_query       = self.embd(query)           # (N, J, d)
        embd_query, _h   = self.query_rnn(embd_query) # (N, J, d)
        embd_query_ex  = embd_query.unsqueeze(1).expand(shape).contiguous() # (N, T, J, d)

        # Match-LSTM layer
        G = to_var(torch.zeros(bs, T, J, d)) # (N, T, J, d)
        
        wh_q = torch.bmm(embd_query, self.Wq.expand(bs, d, d)) # (N, J, d) = (N, J, d)(N, d, d)

        hidden     = to_var(torch.randn([bs, d])) # (N, d)
        cell_state = to_var(torch.randn([bs, d])) # (N, d)
        # TODO bidirectional
        H_r = [hidden]
        for i in range(T):
            wh_p_i = torch.bmm(embd_context[:,i,:].clone().unsqueeze(1), self.Wp.expand(bs, d, d)).squeeze() # (N, 1, d) -> (N, d)
            wh_r_i = torch.bmm(hidden.unsqueeze(1), self.Wr.expand(bs, d, d)).squeeze() # (N, 1, d) -> (N, d)
            sec_elm = (wh_p_i + wh_r_i).unsqueeze(1).expand(bs, J, d) # (N, J, d)

            G[:,i,:,:] = F.tanh( (wh_q + sec_elm).view(-1, d) ).view(bs, J, d) # (N, J, d) # TODO bias

            attn_i = torch.bmm(G[:,i,:,:].clone(), self.w.expand(bs, d, 1)).squeeze() # (N, J)
            attn_query = torch.bmm(attn_i.unsqueeze(1), embd_query).squeeze() # (N, d) 
            z = torch.cat((embd_context[:,i,:], attn_query), 1) # (N, 2d)

            hidden, cell_state = self.match_lstm_cell(z, (hidden, cell_state)) # (N, d), (N, d)
            H_r.append(hidden)
        H_r = torch.stack(H_r, dim=1) # (N, T, d)

        indices = self.ptr_net(H_r) # (N, M, T) , M means (start, end)
        return indices

I hope someone review my code, pleaseā€¦