I am working with a recursive neural network where the forward pass takes roughly 2s on average and the backward pass closer to 7 or 8s. Does this sound like normal behavior? I wonder what I could be doing that causes such a slowdown.
I use a lot of narrow/chunk/cat operations in the model. Could this be a factor?
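This is reasonably normal for a network with lots of narrow/index operations, though I've never seen the ratio quite that high. In the backward pass, each narrow or index op typically has to write its gradient back into a zero-filled tensor the size of the original input, so many small slices can make the backward much more expensive than the forward. Custom kernels might help, but torch.cat, which you're probably using heavily, is already a very nice custom kernel.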
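If you want to check where the backward time actually goes, here is a minimal sketch on a CPU-only toy workload. `many_narrows`, `many_splits` and `time_fwd_bwd` are hypothetical helpers made up for illustration, not anything from your model: the first slices a tensor with repeated `narrow` and rejoins it with `torch.cat`, the second produces the same slices from a single `split` call. Swapping the two is just one thing worth measuring, not a guaranteed speedup; the profiler table is the more useful part, since it breaks the backward down by op.

```python
import time

import torch
from torch.profiler import ProfilerActivity, profile


def many_narrows(x: torch.Tensor, width: int) -> torch.Tensor:
    # Hypothetical stand-in for a model that slices its input with many
    # narrow() calls and stitches the results back together with torch.cat.
    pieces = []
    for start in range(0, x.size(1), width):
        length = min(width, x.size(1) - start)  # last slice may be shorter
        pieces.append(x.narrow(1, start, length).tanh())
    return torch.cat(pieces, dim=1)


def many_splits(x: torch.Tensor, width: int) -> torch.Tensor:
    # Same computation, but every slice comes from one split() call.
    pieces = [p.tanh() for p in x.split(width, dim=1)]
    return torch.cat(pieces, dim=1)


def time_fwd_bwd(fn, x: torch.Tensor, width: int, repeats: int = 5):
    # Average wall-clock seconds for the forward and backward passes.
    # On CUDA, call torch.cuda.synchronize() before each perf_counter()
    # read, because kernels are launched asynchronously.
    fwd = bwd = 0.0
    for _ in range(repeats):
        inp = x.detach().requires_grad_(True)  # fresh leaf each run
        t0 = time.perf_counter()
        out = fn(inp, width)
        t1 = time.perf_counter()
        out.sum().backward()
        t2 = time.perf_counter()
        fwd += t1 - t0
        bwd += t2 - t1
    return fwd / repeats, bwd / repeats


if __name__ == "__main__":
    x = torch.randn(64, 2048)
    for fn in (many_narrows, many_splits):
        f, b = time_fwd_bwd(fn, x, width=4)
        print(f"{fn.__name__}: forward {f:.4f}s, backward {b:.4f}s")

    # Per-op breakdown of one forward + backward with the built-in profiler.
    inp = x.detach().requires_grad_(True)
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        many_narrows(inp, 4).sum().backward()
    print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```

In the profiler output, look for backward entries tied to the slicing ops near the top of the table; if they dominate, restructuring how you slice is likely to pay off more than a custom kernel.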