Hello

I wrote a function in a multimodal model that is supposed to reorder the embedding columns of the two views based on their correlation.

X1 and X2 are the embeddings of the two views.

Below is the function:

```
def get_corr_idx(self, X1, X2, n_rows):
    # correlate every feature of both views: rows of the stacked matrix are features
    corr_matrix = torch.corrcoef(torch.cat([X1.t(), X2.t()]))
    # keep the lower-left block: view-2 features (rows) vs. view-1 features (columns)
    corr_matrix = corr_matrix[-n_rows:, :n_rows]
    # use absolute values because we don't care whether the correlation is negative or positive
    corr_matrix = torch.abs(corr_matrix)
    view1_idx = []
    view2_idx = []
    for i in range(n_rows):
        # position of the strongest remaining correlation
        arg_max = torch.argmax(corr_matrix)
        row = arg_max.item() // n_rows
        column = arg_max.item() % n_rows
        view1_idx.append(column)
        view2_idx.append(row)
        # zero out the matched row and column so each feature is paired only once
        corr_matrix[:, column] = 0
        corr_matrix[row, :] = 0
    # index tensors live on the same device as the embeddings
    a = torch.tensor(view1_idx, dtype=torch.long, device=X1.device)
    b = torch.tensor(view2_idx, dtype=torch.long, device=X2.device)
    # reorder the columns of each view according to the greedy matching
    X1_trans = torch.index_select(X1, 1, a)
    X2_trans = torch.index_select(X2, 1, b)
    return X1_trans, X2_trans
```
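
For context, this is roughly how it gets called in the forward pass (a standalone toy example; the shapes are made up, and I pass `None` for `self` since the method doesn't use it):

```
import torch

# toy example: two views, 128 samples, 64 embedding dimensions each
X1 = torch.randn(128, 64, requires_grad=True)
X2 = torch.randn(128, 64, requires_grad=True)

# self is unused inside the method, so None works for this standalone call
X1_trans, X2_trans = get_corr_idx(None, X1, X2, n_rows=64)
print(X1_trans.shape, X2_trans.shape)  # both torch.Size([128, 64]), columns reordered
```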

I think the problem here is that the `grad_fn` of `a` and `b` is None, which breaks the gradient. Because of that, my model stops learning and the losses stagnate.
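
This is the kind of check that led me to that conclusion (a minimal standalone sketch with toy shapes, mirroring the correlation and argmax steps from the function above):

```
import torch

# toy embeddings with the same layout as my views: samples x features
X1 = torch.randn(8, 4, requires_grad=True)
X2 = torch.randn(8, 4, requires_grad=True)

# same correlation + argmax steps as inside get_corr_idx
corr = torch.abs(torch.corrcoef(torch.cat([X1.t(), X2.t()]))[-4:, :4])
arg_max = torch.argmax(corr)
print(arg_max.grad_fn)  # None -> argmax returns indices, so no gradient flows through the selection
```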

Does anyone have a suggestion for how to do this without breaking the gradient?

Thank you.

Hilmar