Gradient failure due to inplace operation

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Though I have seen a couple of posts on the forum regarding this topic, I'm still confused about what is causing the trouble.

Here’s the piece of code I’m running that produces the gradient runtime error. I’m trying to calculate a representation for each cluster by averaging, where the clusters are mwe and non-mwe. Note: half_retrieval = int(num_retrievals/2), fw is a linear layer, and act is ReLU with inplace=True. Can someone help explain which part is causing the trouble?

    def forward(self, query, retrievals, sigma=4):
        """Score each candidate query against MWE / non-MWE prototype representations.

        Args:
            query: index tensor fed to ``self.embed`` — presumably shape
                (num_candidates, seq_len); TODO confirm against caller.
            retrievals: index tensor of shape (num_candidates, num_retrievals, seq_len);
                the first half along dim 1 are MWE retrievals, the second half non-MWE.
            sigma: unused in the visible body; kept for interface compatibility.

        NOTE(review): the original used ``out.transpose_(2, 1)`` (in-place), which
        rewrites the GRU output tensor that autograd saved for backward and triggers
        "RuntimeError: one of the variables needed for gradient computation has been
        modified by an inplace operation". The out-of-place ``transpose()`` below
        fixes that, and the defensive ``.clone()`` workarounds are then unnecessary.
        """
        # model / metric switches
        PROTOTYPICAL = True
        EUCLIDEAN = True

        num_candidates, num_retrievals, embed_size = retrievals.size()
        half_retrieval = num_retrievals // 2  # integer split point between the two clusters

        # read embeddings
        q = self.embed(query)
        flat_r = self.embed(retrievals.view(-1, embed_size))

        # GRU encoders; transpose(2, 1) moves seq_len last so fw1/fw2 project it away
        out_query, _ = self.gru(q)  # (num_candidates, seq_len, hidden_size)
        out_query = self.act1(self.fw1(out_query.transpose(2, 1)))  # out-of-place: keeps the autograd graph intact
        out_query = self.act2(self.fw2(out_query))
        out_query = out_query.squeeze(-1)  # squeeze only the projected dim — safe when num_candidates == 1

        out_retrieval, _ = self.gru(flat_r)  # (num_candidates*num_retrievals, seq_len, hidden_size)
        out_retrieval = self.act1(self.fw1(out_retrieval.transpose(2, 1)))
        out_retrieval = self.act2(self.fw2(out_retrieval))
        out_retrieval = out_retrieval.squeeze(-1)

        if PROTOTYPICAL:
            out_retrieval = out_retrieval.view(num_candidates, num_retrievals, -1)
            # Prototype per cluster = mean over each half of the retrievals.
            rep_mwe = out_retrieval[:, :half_retrieval, :].mean(1)
            rep_non_mwe = out_retrieval[:, half_retrieval:, :].mean(1)

            if EUCLIDEAN:
                # NOTE(review): despite the EUCLIDEAN flag this computes a dot
                # product (sum of elementwise products), not a Euclidean
                # distance — confirm intent.
                mwe_dist = torch.sum(torch.mul(rep_mwe, out_query), 1)
                non_mwe_dist = torch.sum(torch.mul(rep_non_mwe, out_query), 1)
                dist = torch.stack((mwe_dist, non_mwe_dist), 1)
                sim = -dist  # smaller "distance" => larger similarity
                weights = F.softmax(sim, 1)
                attention_scores = torch.mul(weights, sim)


I found out it’s the in-place transpose (`transpose_`) that causes the issue — replacing it with the out-of-place `transpose` fixes the error.