Slow backpropagation on CPU

Hi!
I’m used to the idea that backpropagation takes roughly a couple of times longer to run than the forward pass (please tell me if this is an unrealistic assumption).

Currently, I’m doing a simple linear regression where an unknown vector x is multiplied by a known matrix A to get a prediction y'. Then I’m minimizing the MSE loss between the true y and the predicted y'.
The matrix A is quite large but sparse, and is therefore stored as a torch.sparse.FloatTensor() on the CPU.

The training steps are

start_time = time.time()
y_pred = torch.matmul(A, x)
print('Matrix product done in ' + str(time.time() - start_time))

loss = loss_fn(RF.mask*RF.x_data, RF.mask*RF_pred)

start_time = time.time()
self.optimizer.zero_grad()
loss.backward()
print('Back propagation done in ' + str(time.time() - start_time))

The output I get is

Matrix product done in 0.36480283737182617
Back propagation done in 25.860530376434326

Is it actually possible that backpropagation takes this long to compute? Is there anything I can try to speed it up, bearing in mind that A won’t fit in GPU memory?
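
For reference, here is a self-contained toy version of the whole setup, with the same training step as above (the sizes and sparsity are just placeholders, much smaller than my real problem):

import time
import torch

# Toy stand-in for the real problem: sizes and sparsity are placeholders.
n_rows, n_cols, nnz = 10000, 5000, 100000

# Random sparse A in COO format (requires_grad is left at its default, False).
indices = torch.stack([torch.randint(0, n_rows, (nnz,)),
                       torch.randint(0, n_cols, (nnz,))])
A = torch.sparse_coo_tensor(indices, torch.randn(nnz), (n_rows, n_cols)).coalesce()

x = torch.randn(n_cols, 1, requires_grad=True)  # the unknown vector
y = torch.randn(n_rows, 1)                      # the known target
loss_fn = torch.nn.MSELoss()

start_time = time.time()
y_pred = torch.matmul(A, x)
print('Matrix product done in ' + str(time.time() - start_time))

loss = loss_fn(y_pred, y)

start_time = time.time()
loss.backward()
print('Back propagation done in ' + str(time.time() - start_time))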

Hi Dick!

What you report does seem surprising. I don’t know enough
about torch.sparse and autograd to have an opinion about
whether this would be expected, but it shouldn’t have to be
this way.

You report that calculating the gradient (loss.backward())
takes about one hundred times as long as calculating the
loss itself.

Here’s how I understand what you have written:

You have a known vector of length n, y, and a sparse n x n
matrix, A. I’ll assume that A is sparse in the conventional
sense that it has order-n non-zero elements, and that it is
stored in such a way that matrix-vector multiplication (e.g.,
A . x) can be performed in order-n time.

You also have a “hidden” length-n vector, x, such that y = A . x,
and you are trying to recover x by performing a gradient-descent
optimization. For a given approximation to x, x’, you calculate
y’ = A . x’. Let’s denote the “residuals” by r = y’ - y. The MSE loss
is then sum_{i = 1, n} r_i^2.

A is fixed (and known); that is, you’re only trying to optimize for
x. Therefore, for a given x’, you need to calculate the gradient
of loss with respect to x’, that is, grad_i = d loss / d x’_i.

Taking the derivative, you get grad_i = 2 (A-transpose . r)_i.
Because A is sparse, calculating y’, the residual r, and the
loss all take order-n time. (If pytorch is smart enough to
remember r, you don’t have to recalculate it, but even so, it’s
only order-n.)

A is sparse, so A-transpose is too. (Supposedly pytorch
knows this, but maybe not.) So calculating the gradient is
also order-n, because it requires only sparse-matrix–vector
multiplication.

So I would expect loss.backward() to take about as long as
calculating loss. (Maybe twice as long if it has to recalculate
r. Maybe three or four times as long if it has to monkey around
traversing the autograd graph.)
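
If you want to convince yourself of that formula, a toy check
along these lines (made-up sizes, and a sum-of-squares loss to
match the formula above) should show that autograd agrees with
2 (A-transpose . r):

import torch

# Toy check of grad = 2 * (A-transpose . r), with made-up sizes.
n, nnz = 1000, 5000
indices = torch.stack([torch.randint(0, n, (nnz,)),
                       torch.randint(0, n, (nnz,))])
A = torch.sparse_coo_tensor(indices, torch.randn(nnz), (n, n)).coalesce()

y = torch.randn(n, 1)
x = torch.randn(n, 1, requires_grad=True)

r = torch.sparse.mm(A, x) - y   # residuals
loss = (r ** 2).sum()           # sum-of-squares loss, as above
loss.backward()

with torch.no_grad():
    grad_manual = 2 * torch.sparse.mm(A.t(), r)   # hand-computed gradient

print(torch.allclose(x.grad, grad_manual, atol=1e-4))   # expect True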

You don’t say what the dimension of your matrix is, but if it
were, say, n = 100, and loss.backward() somehow wasn’t
able to use sparse-matrix multiplication, then loss.backward()
would be order-n^2 (full-matrix–vector multiplication), which
would be about 100 times slower than the order-n calculation
of loss.

As written in your post, you don’t actually use y_pred in your
loss calculation, you don’t say what the dimensionality of
your problem is, you don’t say what RF.mask is, and you don’t
say which of your tensors have requires_grad = True.

But if your actual calculation is as I laid out above, the
loss.backward() calculation could be done in order-n
and take about as long as the loss calculation. Whether
autograd actually does it in order-n, I don’t know.

In short, if loss.backward() manages to use sparse-matrix
multiplication, it should run about as fast as calculating loss.
But if the matrix multiplication somehow gets bumped up to
full-matrix multiplication, then loss.backward() will take about
n times longer, which for “quite large” A could be a lot longer.

I, for one, would love to hear from our pytorch / autograd
experts about whether we should expect autograd to “play
smart” with sparse matrices.

Best regards.

K. Frank

Hi Frank. You summarized the problem perfectly, thanks a lot for your time!!

Just to clarify:

  • Indeed, I’m basically finding the hidden vector x that minimizes the residual |y - Ax|
  • I do have requires_grad = False set for the matrix A
  • I assume torch knows that the transpose of a sparse matrix is sparse too, otherwise I would have had memory issues (did I mention that my matrix is huge? :grimacing: )
  • To give some context: my matrix A is roughly 850,000 x 210,000 with about 0.1% non-zero coefficients. For this reason, I don’t think a full (dense) A is ever materialized, otherwise my RAM would simply be too small to hold it.
  • I forgot to remove RF.mask, sorry for that. It is a matrix of ones and zeros that lets me weight the MSE loss in order to ignore some of the data. The situation doesn’t change if I remove it (its requires_grad is set to False)

A is sparse, so A-transpose is too. (Supposedly pytorch
knows this, but maybe not.)

I think you have found the culprit here: something weird happens to sparse matrices when you transpose them.

I have coded what the backpropagation algorithm is supposed to do, and I timed it, as follows:

tic = time.time()
with torch.no_grad():
    # Forward pass: prediction from the current estimate (here, confusingly, called y)
    RF_pred = torch.sparse.mm(FO.A_torch, y)
    toc1 = time.time()
    print('Matrix-vector product (sec)' + str(toc1-tic))

    # Masked residual between the prediction and the data
    residual = RF.mask*(RF_pred - RF.x_data)
    toc2 = time.time()
    print('Residual calculation + masking (sec)' + str(toc2-toc1))

    # Transpose of the sparse matrix
    At = torch.t(FO.A_torch)
    toc3 = time.time()
    print('Transpose calculation (sec)' + str(toc3 - toc2))

    # Gradient of the loss (up to a constant factor): A-transpose . residual
    grad_y = torch.sparse.mm(At, residual)
    toc4 = time.time()
    print('Matrix (A transpose)-vector product (sec)' + str(toc4 - toc3))

    y.grad = grad_y

The result I get is the following

Matrix-vector product (sec)0.3793952465057373
Residual calculation + masking (sec)0.0005414485931396484
Transpose calculation (sec)1.4321186542510986
Matrix (A transpose)-vector product (sec)24.383739471435547

First of all: taking the transpose takes waaaaaay longer than it should.
Then, calculating the matrix product with the transposed matrix is what really takes forever.
Maybe we should raise an issue on GitHub about it?

Hello Dick!

Those timings do look suspicious.

Your A-transpose . v timing is again about 100 times longer
than your A . v timing. This is in line with the backward / forward
timing discrepancy you reported earlier. And I agree that the
timing for taking the transpose also seems long.

It’s almost as if taking the transpose turns the matrix into a
full matrix (but your comment about memory suggests that
this is not the case).

I did a quick test with a (small) torch.sparse.FloatTensor.
Taking the transpose (torch.t()) resulted in a sparse tensor,
so my simple-minded theory about the transpose being a full
matrix looks unlikely. I think it’s a long shot, but could you print
out the type of your At variable just to make sure that it’s still
sparse (and that nothing else fishy is going on)?
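
Something along these lines (reusing the names from your
snippet) would tell us whether the transpose is still sparse:

At = torch.t(FO.A_torch)
print(type(At))        # expect a sparse tensor type
print(At.is_sparse)    # expect True
print(At.layout)       # expect torch.sparse_coo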

A github issue would seem appropriate. Or perhaps some
pytorch experts – who seem somewhat sparse around here
sometimes – could chime in on this thread with any ideas
about what might be going on.

I do think what you’re doing should work. I can’t think of any
impediment to using the much faster sparse-matrix manipulations
with A-transpose (or with the back-propagation, in general).

However, quoting the TORCH.SPARSE documentation:

This API is currently experimental and may change in the near future.

So maybe this is just a (hopefully known) limitation of a work
in progress.
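
In the meantime, one thing that might be worth trying (this is
only a guess on my part): since A is fixed, build the transpose
once, outside the training loop, and coalesce it so that its
indices end up sorted. Roughly:

At = FO.A_torch.t().coalesce()   # precompute once; coalesce() sorts and merges the COO indices

# then, inside the loop, reuse the precomputed At:
grad_y = torch.sparse.mm(At, residual)

Even if it doesn’t change the sparse.mm timing, it at least
removes the per-iteration cost of taking the transpose itself.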

Best.

K. Frank