Optimize calculation of sub-tensors

I’m trying to optimize a piece of PyTorch code that causes the computation graph to become very large. I’m not sure whether this can be avoided, which is why I’d like some expert advice.

It is part of a toolkit for multi-output Gaussian processes, where data points can belong to different “channels”. The channel of each point is indicated by an integer in the first column of X (i.e. the channel index is in the range [0, channels)). Let’s say we have K channels and N input data points; for each of the K*K channel combinations we calculate a sub-tensor and write it into the corresponding block of the final tensor of shape (N, N).
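
For concreteness, a toy setup could look like this (the sizes and random values are made up purely for illustration):

import torch

channels = 3  # K
N = 8         # number of data points
D = 2         # input dimension per point, excluding the channel column

# first column: channel index in [0, channels), sorted to match the
# monotonicity property mentioned below; remaining columns: inputs
X = torch.cat([
    torch.randint(0, channels, (N, 1)).sort(dim=0).values.float(),
    torch.randn(N, D),
], dim=1)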

# get channel indices for all data points
c = X[:, 0].long()

# create a boolean channel mask for X, one mask per channel
m = [c == i for i in range(channels)]

# select the X inputs belonging to each channel (dropping the channel column)
x = [X[m[i], 1:] for i in range(channels)]

# row/column indices into the final tensor for each channel; shape (n_i, 1)
r = [torch.nonzero(m[i]) for i in range(channels)]

# final tensor
res = torch.empty(X.shape[0], X.shape[0])  # NxN

# loop over channel combinations
for i in range(channels):
    for j in range(channels):
        k = calculate_subtensor(i, j, x[i], x[j])
        # r[i] is (n_i, 1) and r[j].T is (1, n_j), so the index pair
        # broadcasts to the (n_i, n_j) block that k fills
        res[r[i], r[j].T] = k  # k has shape (x[i].shape[0], x[j].shape[0])

The computation graph ends up with many Select, Unsqueeze, Permute, and View nodes, and the problem is that the number of nodes scales quadratically with the number of channels. Can this code be improved?

Note that the channel index increases monotonically along the data points, which this code doesn’t take advantage of.
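
Since the channel index is sorted, each channel occupies a contiguous range of rows, so the boolean masks and torch.nonzero calls can be replaced by plain slicing. Slice assignment writes into a contiguous block and should avoid the advanced-indexing nodes entirely; a sketch under that assumption, reusing calculate_subtensor from above:

# points per channel, then block boundaries [0, n_0, n_0+n_1, ...]
counts = torch.bincount(c, minlength=channels)
offsets = [0] + counts.cumsum(0).tolist()

# per-channel inputs via contiguous slices instead of masks
x = [X[offsets[i]:offsets[i+1], 1:] for i in range(channels)]

res = torch.empty(X.shape[0], X.shape[0])
for i in range(channels):
    for j in range(channels):
        k = calculate_subtensor(i, j, x[i], x[j])
        # plain slice assignment on a contiguous block, no fancy indexing
        res[offsets[i]:offsets[i+1], offsets[j]:offsets[j+1]] = k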

The problem is likely that the loops run in Python, which pegs my CPU at 100% while my GPU sits at only 60% utilization. Surely the loops can and should be unrolled so that the full computation graph can be generated up front. I suppose I need to play around with torch.jit?
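
A minimal torch.jit.script sketch of the sliced version, assuming calculate_subtensor is itself scriptable (TorchScript compiles called functions recursively), which would at least move the loop dispatch out of Python:

import torch
from typing import List

@torch.jit.script
def assemble(xs: List[torch.Tensor], offsets: List[int], channels: int) -> torch.Tensor:
    n = offsets[-1]  # total number of points N
    res = torch.empty(n, n)
    for i in range(channels):
        for j in range(channels):
            # calculate_subtensor is assumed to be defined above and scriptable
            k = calculate_subtensor(i, j, xs[i], xs[j])
            res[offsets[i]:offsets[i+1], offsets[j]:offsets[j+1]] = k
    return res

Note the distinction: torch.jit.trace records the ops actually executed and therefore unrolls the loops into a flat graph (with the channel structure baked in at trace time), while torch.jit.script keeps the loops but runs them outside the Python interpreter.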