Bottleneck in training due to double for loop

Greetings from Italy!

So, I coded a custom loss function to train my autoencoder. The problem is the following:
the function I’m using was coded by someone else and, for my purposes, it needs to work on vectors.
For instance, given an n x m dataset, where n is the number of samples and m the number of features, I need the function to run on every possible pair of features: so I’ll pass the function features 0 and 1, then 0 and 2, and so on.
So I wrote a double for loop to check whether the function would work inside torch. The problem is that the double for loop is a huge bottleneck for my model, and training is really slow.
I tried to optimize the double for loop + function section and vectorize it, but I wasn’t able to.
This is the code:

    def rbf_dot(self, pattern1, pattern2, deg):
        size1 = pattern1.size() 
        size2 = pattern2.size()

        G = torch.sum(pattern1 * pattern1, axis = 1).reshape(size1[0], 1)
        H = torch.sum(pattern2 * pattern2, axis = 1).reshape(size2[0], 1)

        Q = torch.tile(G, (1, size2[0]))
        R = torch.tile(torch.transpose(H, 0, 1), (size1[0], 1))

        H = Q + R - 2 * torch.matmul(pattern1, torch.transpose(pattern2, 0, 1))

        H = torch.exp(-(H/2) / (deg**2))

        return H

    def function(self, X, Y):
        """
        X, Y : torch tensors of size [n, 1]
        n    : number of observations
        """
        n = X.size(0)
        # ----- width of X -----
        Xmed = X

        G = torch.sum(Xmed * Xmed, axis = 1).reshape(n, 1)
        Q = torch.tile(G, (1, n))
        R = torch.tile(torch.transpose(G, 0, 1), (n, 1))
        dists = Q + R - 2 * torch.matmul(Xmed, torch.transpose(Xmed, 0, 1))
        dists = dists - torch.tril(dists)
        dists = dists.reshape(n**2, 1)
        width_x = torch.sqrt( 0.5 * torch.median(dists[dists>0]))

        # ----- width of Y -----
        Ymed = Y

        G = torch.sum(Ymed * Ymed, axis = 1).reshape(n, 1)
        Q = torch.tile(G, (1, n))
        R = torch.tile(torch.transpose(G, 0, 1), (n, 1))

        dists = Q + R - 2 * torch.matmul(Ymed, torch.transpose(Ymed, 0, 1))
        dists = dists - torch.tril(dists)
        dists = dists.reshape(n**2, 1)
        width_y = torch.sqrt( 0.5 * torch.median(dists[dists>0]))

        # ----- -----
        H = torch.eye(n, device = dists.device) - ( torch.ones(n,n, device = dists.device) / n ) 

        K = self.rbf_dot(X, X, width_x)
        L = self.rbf_dot(Y, Y, width_y)

        Kc = torch.matmul(torch.matmul(H, K), H)
        Lc = torch.matmul(torch.matmul(H, L), H)

        testStat = torch.sum(torch.transpose(Kc, 0, 1) * Lc) / n

        return testStat

    def double_for_loop_section(self, data):
        # n : rows,     observations
        # m : columns,  features
        # so that: a row represents all the features of a single user;
        #          a column represents the same feature for all users

        n, m = data.size()
        stat = torch.zeros(m, m, device = data.device)
        # self.device

        for i in range(m):
            for j in range(i, m):
                if i != j:
                    x = data[:, i]
                    y = data[:, j]
                    x_2D = x[:, None] 
                    y_2D = y[:, None]
                    stat[i, j] = self.function(x_2D, y_2D)
                    # matrix is symmetric
                    stat[j, i] = stat[i, j]
        return stat

The double_for_loop_section method is called inside the loss function, passing it the hidden-layer output of the encoder.
Could someone help me vectorize it?

Hi, could you post a more simplified version of your code, just a toy case that contains all the essential computation? Maybe I can help.


Hi Composite!

As an aside, your code is unhelpfully obtuse.

You loop over pairs of columns from data and then call self.function()
on that pair.

But up until the line in which you calculate testStat, Kc depends only
on the first column in the pair and Lc only on the second. So outside
of your double loop you can precompute everything that goes into Kc
and Lc.

First do this precomputation to clear away all of the confusion.

Then, to “vectorize” the computation of testStat and place it into the
matrix stat, take a look at torch.einsum().

Good luck!

K. Frank

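For reference, a minimal sketch of the precompute-then-einsum approach K. Frank describes, with simplified standalone helpers (`rbf_kernel`, `median_width`, `pairwise_stat` are illustrative names, not from the original code; the bandwidth rule follows the original median-of-positive-distances computation):

```python
import torch

def rbf_kernel(x, width):
    # x: (n, 1) column; returns exp(-||x_i - x_j||^2 / (2 * width^2))
    d2 = (x - x.T) ** 2
    return torch.exp(-d2 / (2 * width ** 2))

def median_width(x):
    # Bandwidth from the median of the positive strictly-upper-triangular
    # pairwise squared distances, as in the original width_x / width_y.
    d2 = (x - x.T) ** 2
    vals = d2[torch.triu(d2, diagonal=1) > 0]
    return torch.sqrt(0.5 * torch.median(vals))

def pairwise_stat(data):
    # data: (n, m); returns the symmetric (m, m) matrix of test statistics.
    n, m = data.shape
    H = torch.eye(n) - torch.ones(n, n) / n          # centering matrix
    # Precompute one centered kernel per column, outside any double loop.
    cols = [data[:, i:i + 1] for i in range(m)]
    Kc = torch.stack([H @ rbf_kernel(c, median_width(c)) @ H for c in cols])
    # stat[i, j] = sum(Kc[i] * Kc[j]) / n for every pair at once
    # (each Kc[i] is symmetric, so the transpose in the original drops out).
    stat = torch.einsum('iab,jab->ij', Kc, Kc) / n
    stat.fill_diagonal_(0.0)                         # the original loop skips i == j
    return stat
```

The stack of centered kernels costs O(m·n²) memory, so for very wide hidden layers it may be worth chunking the einsum over blocks of columns.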

Unfortunately, a simpler version of the code would be useless; all the operations you see are essential to computing what I need.
Btw, I’ve been able to vectorize it :P, thanks anyway!

This is exactly what I was doing.
I didn’t write the code, I just translated it from numpy to torch.
I also realized the code was utterly redundant!
Btw, I’ve been able to vectorize it :P
I just need to vectorize one last thing which, for now, I assumed to be constant :D
Thank you for the help, einsum was so cool to use :P