Arbitrary pairwise operation

Hello! I am attempting to compute a dense N x N matrix by evaluating a function f(x_i, x_j) on all N^2 pairs of rows of X (X.shape = (N, F)). An example would be computing a Euclidean distance matrix (if f were Euclidean distance), but in general I would like f to be a more complicated function. However, I'm really struggling to find a way to do this in PyTorch without blowing up my memory. For my application, N is ~2000 and F is ~128. My first (naive) approach is to generate two index vectors, idx_r and idx_c:

idx_r = [0, 0, 0, ..., 1, 1, 1, ..., N-1, N-1, N-1]
idx_c = [0, 1, 2, ..., N-1, 0, 1, 2, ..., N-1, ..., 0, 1, 2, ..., N-1]

and then compute:

X_r = X[idx_r]
X_c = X[idx_c]  
y = f(X_r, X_c)
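
To make this concrete, here is a minimal runnable sketch of the naive approach. Euclidean distance stands in for f and the sizes are tiny; both are placeholders for illustration, not my real setup. The last line is just the arithmetic for how large one gathered tensor gets at my real sizes in float32.

```python
import torch

def f(a, b):
    # Placeholder pairwise function (Euclidean distance between matching rows).
    # My real f is more complicated; this is only an assumption for the example.
    return (a - b).pow(2).sum(dim=-1).sqrt()

N, F = 5, 3  # tiny illustrative sizes; my real case is N ~ 2000, F ~ 128
X = torch.randn(N, F)

idx_r = torch.arange(N).repeat_interleave(N)  # [0, 0, ..., 1, 1, ..., N-1, N-1]
idx_c = torch.arange(N).repeat(N)             # [0, 1, ..., N-1, 0, 1, ..., N-1]

X_r = X[idx_r]  # shape (N*N, F), explicitly materialized
X_c = X[idx_c]  # shape (N*N, F), explicitly materialized
y = f(X_r, X_c).view(N, N)  # y[i, j] = f(x_i, x_j)

# Size of ONE gathered tensor at N = 2000, F = 128 in float32 (4 bytes each)
bytes_per_copy = 2000 * 2000 * 128 * 4  # 2,048,000,000 bytes, about 2 GB
```

So X_r and X_c together are already around 4 GB at my sizes, before any intermediates that f creates.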

but explicitly instantiating X_r and X_c (each of shape (N^2, F)) easily consumes all of my GPU memory. Are there any recommended PyTorch tricks that would let me compute the elements of X_r and X_c on the fly, or any other suggested ways of dealing with this?

