Advice for debugging slow backward pass

I am working with a recursive neural network where the forward pass takes roughly 2 s on average and the backward pass closer to 7 or 8 s. Does this sound like normal behavior? I'm wondering what I might be doing that causes such a slowdown.
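Before comparing numbers, it helps to make sure forward and backward are timed separately and fairly. Below is a minimal sketch, assuming a model that returns a tensor and using `.sum()` as a placeholder loss. The helper name `time_forward_backward` and the toy `Sequential` model are hypothetical. On a GPU the explicit synchronization matters: kernels launch asynchronously, so without it, forward work can get billed to the backward timer.

```python
import time
import torch

def time_forward_backward(model, inputs, n_iters=5, n_warmup=1):
    """Average wall-clock seconds for forward and backward, measured separately."""
    use_cuda = torch.cuda.is_available()

    def sync():
        # CUDA kernels launch asynchronously; without a sync, time spent in
        # forward kernels can end up attributed to backward (or vice versa).
        if use_cuda:
            torch.cuda.synchronize()

    fwd_total = bwd_total = 0.0
    for it in range(n_warmup + n_iters):
        model.zero_grad(set_to_none=True)
        sync()
        t0 = time.perf_counter()
        loss = model(inputs).sum()  # placeholder loss; substitute your real one
        sync()
        t1 = time.perf_counter()
        loss.backward()
        sync()
        t2 = time.perf_counter()
        if it >= n_warmup:  # skip warm-up iterations (allocator, lazy init)
            fwd_total += t1 - t0
            bwd_total += t2 - t1
    return fwd_total / n_iters, bwd_total / n_iters

model = torch.nn.Sequential(
    torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 1)
)
x = torch.randn(64, 256)
fwd, bwd = time_forward_backward(model, x)
print(f"forward {fwd * 1e3:.2f} ms, backward {bwd * 1e3:.2f} ms, ratio {bwd / fwd:.1f}x")
```

The warm-up iteration is skipped because the first pass often includes one-off costs such as memory allocation, which would skew a small average.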

I use a lot of narrow/chunk/cat operations in the model. Could this be a factor?
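One way to check whether narrow/chunk/cat is actually the culprit is to profile a single training step and see which backward nodes dominate. Here is a sketch using `torch.profiler`, assuming CPU execution (add `ProfilerActivity.CUDA` for a GPU model). `NarrowCatNet` is a hypothetical toy model that mimics heavy narrow/cat usage, and `profile_step` is likewise a made-up helper name.

```python
import torch
from torch.profiler import ProfilerActivity, profile

class NarrowCatNet(torch.nn.Module):
    """Hypothetical toy model: slices the input with narrow(), applies a
    shared linear layer to each slice, and re-joins the slices with cat()."""

    def __init__(self, dim=64, n_pieces=8):
        super().__init__()
        self.n_pieces = n_pieces
        self.piece = torch.nn.Linear(dim // n_pieces, dim // n_pieces)

    def forward(self, x):
        size = x.shape[1] // self.n_pieces
        pieces = [self.piece(x.narrow(1, i * size, size)) for i in range(self.n_pieces)]
        return torch.cat(pieces, dim=1)

def profile_step(model, inputs, row_limit=15):
    # CPU-only for portability; add ProfilerActivity.CUDA when profiling a GPU model.
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        loss = model(inputs).sum()
        loss.backward()
    # Backward work appears under autograd node names (e.g. slice and cat
    # backward nodes; exact names vary across PyTorch versions).
    return prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=row_limit)

print(profile_step(NarrowCatNet(), torch.randn(32, 64)))
```

If slice/narrow backward entries sit near the top of the table, the slicing pattern is worth restructuring. If matmuls dominate instead, the ratio is probably mostly arithmetic.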

From what I’ve heard, the backward pass does take roughly 3x as long as the forward pass.
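For context, rules of thumb like this usually come from FLOP counts for dense layers. Below is a rough count for a single linear layer $Y = XW$, assuming $X \in \mathbb{R}^{B \times n}$, $W \in \mathbb{R}^{n \times m}$ and ignoring bias and activations. Backward needs two matmuls of the same size as the forward one, which gives about 2x forward for backward alone, or about 3x forward for a full forward+backward step. Memory-bound ops such as narrow/cat don't follow this count; their backward cost is dominated by allocations and copies.

```latex
\text{forward: } Y = XW \quad\Rightarrow\quad \approx 2Bnm \text{ FLOPs}

\text{backward: } \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W^{\top} \;(\approx 2Bnm),
\qquad
\frac{\partial L}{\partial W} = X^{\top} \frac{\partial L}{\partial Y} \;(\approx 2Bnm)

\Rightarrow\ \text{backward} \approx 4Bnm \approx 2 \times \text{forward},
\qquad \text{forward} + \text{backward} \approx 3 \times \text{forward}
```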

@Rinku_Jadhav2014 Good to know! Is this a “rule”, or can anything be done to optimize the backward pass, for instance by writing custom kernels?

This is reasonably normal for a network with lots of narrow/index operations, though I’ve never seen the ratio quite that high. Custom kernels might help but torch.cat, which you’re probably using heavily, is already a very nice custom kernel.
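One thing worth trying before writing custom kernels is to take many slices of the same tensor with a single `torch.split` (or `chunk`) call instead of many separate `narrow` calls. This is a sketch under the assumption that your model slices one tensor repeatedly; the function names are hypothetical. As I understand PyTorch's autograd, each `narrow` gets its own backward node that allocates a zero tensor shaped like the whole input and copies its slice's gradient into it, and autograd then sums all of those full-size tensors. A single `split` gets one backward node that concatenates the per-piece gradients, which is the kind of `cat` kernel mentioned above. Both produce the same gradient. Whether this is faster for your model is something to measure.

```python
import torch

def weighted_sum_narrow(x, n_chunks):
    # One narrow() per piece: each piece gets its own backward node, and each
    # node's backward builds an x-shaped zero tensor holding that piece's
    # gradient; autograd then sums n_chunks of these full-size tensors.
    size = x.shape[0] // n_chunks
    pieces = [x.narrow(0, i * size, size) for i in range(n_chunks)]
    return sum((p * (i + 1)).sum() for i, p in enumerate(pieces))

def weighted_sum_split(x, n_chunks):
    # A single split() call: one backward node that concatenates the
    # per-piece gradients into one x-shaped tensor.
    size = x.shape[0] // n_chunks
    pieces = torch.split(x, size, dim=0)
    return sum((p * (i + 1)).sum() for i, p in enumerate(pieces))

x = torch.randn(12, 3, requires_grad=True)
(g_narrow,) = torch.autograd.grad(weighted_sum_narrow(x, 4), x)
(g_split,) = torch.autograd.grad(weighted_sum_split(x, 4), x)
print(torch.equal(g_narrow, g_split))  # True: same gradient, different backward graph
```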


Hi, can you explain why index operations slow down backpropagation?
I have a similar question here: indexing-is-very-slow-for-backpropagation
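A small illustration of why indexing tends to cost more in backward than in forward: the forward pass is a gather, but the backward pass has to build a gradient the full size of the source tensor and scatter-add the incoming gradient into it. Duplicate indices must be summed, which on a GPU typically means atomic adds or a sort. The sketch below checks this by comparing autograd's result with a manual `index_add_` into a zero tensor. The shapes and indices are arbitrary assumptions chosen for the example.

```python
import torch

x = torch.randn(5, 3, requires_grad=True)
idx = torch.tensor([0, 2, 2, 4])  # index 2 appears twice
y = x[idx]                        # advanced indexing: a gather-like copy
grad_y = torch.randn_like(y)

(grad_x,) = torch.autograd.grad(y, x, grad_outputs=grad_y)

# Manual equivalent of the indexing backward: a zero tensor shaped like x,
# with rows of grad_y scatter-added back; duplicate indices accumulate.
manual = torch.zeros_like(x).index_add_(0, idx, grad_y)
print(torch.allclose(grad_x, manual))  # True
```

Rows that were never indexed (1 and 3 here) still receive an explicit zero gradient. So even a small index into a large tensor pays for a full-size allocation in backward.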