Reordering columns breaks the gradient

Hello

I created a function in a multimodal model that is supposed to reorder columns based on their correlation.
X1 and X2 are the embeddings of the two views.

Below is the function:

    import torch

    def get_corr_idx(self, X1, X2, n_rows):
        # correlations between all features of both views (.t() turns features into rows)
        corr_matrix = torch.corrcoef(torch.concat([X1.t(), X2.t()]))
        # keep only the lower-left cross-view block: rows are X2 features, columns are X1 features
        corr_matrix = corr_matrix[-n_rows:, :n_rows]
        # make the matrix absolute, because we don't care about negative or positive correlation
        corr_matrix = torch.abs(corr_matrix)
        view1_idx = []
        view2_idx = []
        for i in range(0, n_rows):
            # greedily pick the strongest remaining (X2 feature, X1 feature) pair
            arg_max = torch.argmax(corr_matrix)
            row = arg_max.item() / n_rows
            row = int(row)
            column = arg_max.item() % n_rows
            view1_idx.append(column)
            view2_idx.append(row)
            # exclude the chosen features from further matching
            corr_matrix[:, column] = 0
            corr_matrix[row, :] = 0
        a = torch.LongTensor(view1_idx).to(torch.device("cuda:0"))
        b = torch.LongTensor(view2_idx).to(torch.device("cuda:0"))
        X1_trans = torch.index_select(X1, 1, a)
        X2_trans = torch.index_select(X2, 1, b)
        return X1_trans, X2_trans
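The greedy pairing loop in get_corr_idx can be checked in isolation on a hand-made 2x2 matrix of absolute correlations. The sketch below uses a hypothetical standalone helper, greedy_match, which is not part of the original model. It splits the flat argmax index with divmod, and, as a deliberate change from the original, it masks taken rows and columns with -1 instead of 0, so an already-used feature can never win a later argmax when the remaining correlations happen to be exactly 0.

```python
import torch

def greedy_match(abs_corr: torch.Tensor):
    """Hypothetical helper: greedy one-to-one pairing of features.

    Rows index view-2 features, columns index view-1 features
    (same layout as the cross-view block in get_corr_idx).
    """
    corr = abs_corr.clone()  # work on a copy so the caller's matrix is untouched
    n = corr.shape[0]
    view1_idx, view2_idx = [], []
    for _ in range(n):
        flat = torch.argmax(corr).item()  # index into the flattened n x n matrix
        row, column = divmod(flat, n)     # integer split, no float round trip
        view1_idx.append(column)
        view2_idx.append(row)
        # abs correlations are >= 0, so -1 guarantees taken entries never win again
        corr[:, column] = -1
        corr[row, :] = -1
    return view1_idx, view2_idx

abs_corr = torch.tensor([[0.1, 0.9],
                         [0.8, 0.2]])
print(greedy_match(abs_corr))  # ([1, 0], [0, 1])
```

The strongest pair (0.9) matches view-2 feature 0 with view-1 feature 1, which leaves only view-2 feature 1 and view-1 feature 0 for the second step.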

I think the problem here is that the grad_fn of “a” and “b” is None, which breaks the gradient. Because of that, my model stops learning and the losses stagnate.
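Index tensors such as a and b are not expected to have a grad_fn: indices are discrete, so autograd never differentiates with respect to them. The relevant question is whether gradient still reaches X1 and X2 through torch.index_select. A minimal check on toy CPU tensors (the 8x4 shape and the permutation are made up for illustration) suggests that it does:

```python
import torch

torch.manual_seed(0)
X1 = torch.randn(8, 4, requires_grad=True)  # toy stand-in for a view-1 embedding batch

# Index built from plain Python ints, like a in get_corr_idx: it has no grad_fn.
idx = torch.tensor([2, 0, 3, 1], dtype=torch.long)
print(idx.grad_fn, idx.requires_grad)       # None False

X1_trans = torch.index_select(X1, 1, idx)   # columns reordered
print(X1_trans.grad_fn is not None)         # True: still connected to X1

X1_trans.sum().backward()
print(X1.grad.shape)                        # torch.Size([8, 4])
```

Because every column is selected exactly once, each entry of X1.grad is 1 here, so the missing grad_fn on the index by itself does not cut X1 off from the loss.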

Does anyone have a suggestion on how to do this?

Thank you.

Hilmar

It’s not about the reordering.
When you call torch.LongTensor, you are re-instantiating the tensor and breaking the attached graph.
You can cast to long via tensor.long() instead (note that tensor.int() casts to int32, not int64).
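Whichever way the indices are produced, casting to an integer dtype with .long() or .int() yields a tensor that autograd does not track. A small sketch on a made-up float tensor illustrates this:

```python
import torch

scores = torch.randn(5, requires_grad=True)  # toy float tensor tracked by autograd

as_long = scores.long()  # int64 copy
as_int = scores.int()    # int32 copy

print(as_long.dtype, as_long.requires_grad, as_long.grad_fn)  # torch.int64 False None
print(as_int.dtype, as_int.requires_grad, as_int.grad_fn)     # torch.int32 False None
```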

In addition to what @JuanFMontesinos correctly pointed out, I’m also thinking about int(row), which makes row a plain Python int object.

That is to say, even if you use what Juan has suggested (so that the data isn’t copied by re-instantiating the tensor), that int object (row) will still have no graph attached to it: graphs are attached to tensor objects only.
Makes sense?
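The same holds one step earlier in the loop: torch.argmax already returns an integer tensor that is not part of the graph, and .item() then turns it into a Python int. A toy illustration (the random 3x3 matrix stands in for corr_matrix):

```python
import torch

corr = torch.rand(3, 3, requires_grad=True)  # toy matrix tracked by autograd

arg_max = torch.argmax(corr)  # int64 tensor: argmax is not differentiable
flat = arg_max.item()         # plain Python int, no tensor and no graph
row = flat // 3               # integer division, equivalent to int(flat / 3) here

print(arg_max.dtype, arg_max.requires_grad)     # torch.int64 False
print(type(flat).__name__, type(row).__name__)  # int int
```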

Also, I don’t fully understand your use case, but here’s something:

You are constructing a and b as long tensors while also wanting them to have requires_grad=True, or at least to take part in the gradient computation somehow. But as far as I know, only tensors of floating-point or complex dtype can have requires_grad=True.

I’m not sure whether this would cause any errors for your use case, though.
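A quick way to see the dtype restriction is to request gradients on both a float tensor and a long tensor. In the sketch below (toy values only), the float tensor is accepted, while the long tensor makes PyTorch raise a RuntimeError:

```python
import torch

# Floating-point tensors can require gradients.
w = torch.zeros(3, requires_grad=True)
print(w.requires_grad)  # True

# Integer (long) tensors cannot: PyTorch raises a RuntimeError.
raised = False
try:
    torch.tensor([0, 2, 1], dtype=torch.long, requires_grad=True)
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```

So index tensors like a and b can only ever be plain, untracked long tensors.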

Hoping it helps,
Srishti