Reordering columns breaks gradient


I created a function in a multimodal model that is supposed to reorder columns based on their correlation.
X1 and X2 are the embeddings of the two views.

Below is the function:

    def get_corr_idx(self, X1, X2, n_rows):
        corr_matrix = torch.corrcoef(torch.concat([X1.t(), X2.t()]))
        # take the cross-correlation block (X2 dimensions as rows, X1 dimensions as columns)
        corr_matrix = corr_matrix[-n_rows:, :n_rows]
        # take absolute values because we don't care whether the correlation is negative or positive
        corr_matrix = torch.abs(corr_matrix)
        view1_idx = []
        view2_idx = []
        for i in range(n_rows):
            arg_max = torch.argmax(corr_matrix)
            row = arg_max.item() // n_rows
            column = arg_max.item() % n_rows
            view1_idx.append(column)
            view2_idx.append(row)
            # zero out the matched row and column so they are not picked again
            corr_matrix[:, column] = 0
            corr_matrix[row, :] = 0
        a = torch.LongTensor(view1_idx).to(torch.device("cuda:0"))
        b = torch.LongTensor(view2_idx).to(torch.device("cuda:0"))
        X1_trans = torch.index_select(X1, 1, a)
        X2_trans = torch.index_select(X2, 1, b)
        return X1_trans, X2_trans

I think the problem here is that the grad_fn of `a` and `b` is None, which breaks the gradient. Because of that, my model stops learning and the losses stagnate.

Does anyone have a suggestion on how to do this properly?

Thank you.


It's not about the reordering.
When you call

    a = torch.LongTensor(view1_idx).to(torch.device("cuda:0"))

you are re-instantiating the tensor and breaking the attached graph.
You could cast to long instead, e.g. via tensor.long().

In addition to what @JuanFMontesinos correctly pointed out, I'm also thinking about `int(row)`, which makes `row` a plain Python int.

That is to say, even if you use what Juan has suggested (so that the data isn't copied by re-instantiating the tensor), that int object (`row`) will still have no graph attached to it - graphs are for tensor objects only.
Makes sense?
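To illustrate the point (a standalone toy snippet, not your model): `argmax` already returns an integer tensor that sits outside the autograd graph, and `int()`/`.item()` produce plain Python numbers, which can never carry a graph at all.

```python
import torch

x = torch.tensor([1.5, 3.0], requires_grad=True)
i = torch.argmax(x)   # long tensor: grad_fn is None, argmax is not differentiable
j = int(i.item())     # plain Python int: not a tensor, so no graph at all
print(i.grad_fn, type(j))  # None <class 'int'>
```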

Also, I do not fully know or understand your use case, but here's something -

you are constructing `a` and `b` as long tensors and also wanting them to have `requires_grad=True` or something, essentially wanting them to take part in the gradient computation. But, from what I know, only tensors of floating-point or complex dtype can have `requires_grad=True`.

Not very sure if this would cause any errors for your use case though.
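The good news is that the index tensors don't actually need a graph. A small sketch under the thread's setup (random stand-ins for the embeddings, CPU instead of cuda:0): `torch.index_select` is differentiable with respect to the *data* tensor, so gradients still reach `X1` even though `a` is an integer tensor with no grad_fn.

```python
import torch

X1 = torch.randn(5, 4, requires_grad=True)
a = torch.tensor([2, 0, 1, 3])           # long indices: no graph, and that's fine
X1_trans = torch.index_select(X1, 1, a)  # graph flows through X1, not through a
X1_trans.sum().backward()
print(X1.grad is not None)               # True: the gradient reaches X1
```

So as long as you only use `a` and `b` to select columns, re-creating them with `torch.LongTensor` each step doesn't break learning by itself; what matters is that `X1` and `X2` stay connected to the graph.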

Hope it helps,