Diagnosing a slow backward pass when computing an RL gradient over a minibatch

For an RL task, I want to do parameter updates over minibatches of, e.g., 100 or 500 episodes. All tensor operations in my code are unbatched (i.e. the batch dimension, if any, is 1). This is because the RL controller is actually composed of multiple separate models, which makes it hard to program the forward pass as one chain of tensor operations.

In pseudocode, per minibatch, I do:

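```python
import torch
from torch.nn.functional import mse_loss

# play_episode, target_state and optimizer are defined elsewhere in my code.
loss_terms = []

for i in range(100):  # one minibatch of 100 episodes
    result_state = play_episode()
    loss = mse_loss(result_state, target_state)
    loss_terms.append(loss)  # collect the loss tensor for this episode

mean_loss = torch.mean(torch.stack(loss_terms))
optimizer.zero_grad()
mean_loss.backward()
optimizer.step()
```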

The backward pass is very slow. In other topics I read that `torch.stack` is a slow operation, but indexing is allegedly slow too, so I'm not sure that pre-allocating a minibatch-sized tensor and filling it by index would help much. Is the computational graph simply too large, and should I compute the gradient per episode instead? What's the best practice here?
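On the "gradient per episode" question: calling `backward()` once per episode lets `.grad` accumulate across calls (commonly called gradient accumulation), and scaling each loss by `1/n` makes the accumulated gradient match the gradient of the mean. The sketch below checks that on a toy problem; the `Linear` model and the random inputs/targets are hypothetical stand-ins for the multi-model controller and `play_episode()`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for the controller: one small linear model.
model = torch.nn.Linear(4, 4)
# Hypothetical per-episode results and targets (batch dimension 1, as in the question).
inputs = [torch.randn(1, 4) for _ in range(8)]
targets = [torch.randn(1, 4) for _ in range(8)]
n = len(inputs)

# Variant A: keep every episode's graph alive, stack, average, one backward call.
model.zero_grad()
losses = [F.mse_loss(model(x), y) for x, y in zip(inputs, targets)]
torch.mean(torch.stack(losses)).backward()
grad_stacked = model.weight.grad.clone()

# Variant B: one backward call per episode; .grad accumulates across calls,
# and each episode's graph is freed right after its backward call.
model.zero_grad()
for x, y in zip(inputs, targets):
    loss = F.mse_loss(model(x), y)
    (loss / n).backward()  # scale by 1/n so the accumulated gradient equals the mean's gradient
grad_accumulated = model.weight.grad.clone()

print(torch.allclose(grad_stacked, grad_accumulated, atol=1e-6))  # expected: True
```

In a real loop, `optimizer.step()` would still run once per minibatch, after the last episode. The total backward work stays roughly the same, so this mainly lowers peak memory rather than making the backward pass itself faster.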

Here's the profiler output for 3 iterations with minibatch size 100. Everything runs on the CPU. (Also: cProfile worked fine, but the autograd profiler completely flooded my RAM, to the extent that profiling with minibatch size 500 crashed. Does that make sense?)
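The RAM usage is plausible: the autograd profiler presumably keeps a record for every operator call, and the cProfile numbers show hundreds of thousands of calls per few iterations. A minimal sketch, assuming a small hypothetical `Sequential` model in place of the real controller: profile only a handful of episodes, then aggregate by operator name with `key_averages()`, so each row becomes a total over all calls instead of a single event (the table in the question lists individual events, each with Calls = 1).

```python
import torch

# Hypothetical small model standing in for the controller.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1)
)
x = torch.randn(1, 8)  # batch dimension 1, as in the question

# Profile only a few "episodes" rather than a whole minibatch of 500,
# since the number of recorded events grows with the number of operator calls.
with torch.autograd.profiler.profile() as prof:
    for _ in range(5):
        loss = model(x).pow(2).mean()
        loss.backward()

# Group events by operator name so each row is a total over all its calls.
averages = prof.key_averages()
print(averages.table(sort_by="cpu_time_total", row_limit=15))
```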

```
--------------------------------------------------------------------------------
  cProfile output
--------------------------------------------------------------------------------
         23976946 function calls (22476653 primitive calls) in 80.041 seconds

   Ordered by: internal time
   List reduced from 2332 to 15 due to restriction <15>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        6   26.195    4.366   26.195    4.366 {method 'run_backward' of 'torch._C._EngineBase' objects}
   744975   12.643    0.000   12.643    0.000 {method 'matmul' of 'torch._C._TensorBase' objects}
   750393    3.830    0.000   19.936    0.000 \torch\nn\functional.py:1010(linear)
1541421/48762    3.787    0.000   45.183    0.001 \torch\nn\modules\module.py:471(__call__)
   568890    3.180    0.000    3.180    0.000 {built-in method torch._C._nn.threshold}
   750393    3.140    0.000    3.140    0.000 {method 't' of 'torch._C._TensorBase' objects}
   455112    2.636    0.000    8.157    0.000 \ibp-pytorch\utilities.py:5(tensor_from)
   455112    1.974    0.000    1.974    0.000 {built-in method cat}
   744975    1.821    0.000   22.271    0.000 \torch\nn\modules\linear.py:54(forward)
    43344    1.732    0.000   36.132    0.001 \ibp-pytorch\imaginator.py:71(<listcomp>)
   176085    1.541    0.000   31.485    0.000 \torch\nn\modules\container.py:89(forward)
      301    1.340    0.004   27.767    0.092 \ibp-pytorch\imagination_based_planner.py:153(new_episode)
    43344    1.056    0.000   43.402    0.001 \ibp-pytorch\imaginator.py:70(forward)
   308826    1.047    0.000    1.047    0.000 {built-in method tensor}
   558054    0.968    0.000    0.968    0.000 {method 'float' of 'torch._C._TensorBase' objects}


--------------------------------------------------------------------------------
  autograd profiler output (CPU mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

---------------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                          CPU time        CUDA time            Calls        CPU total       CUDA total
---------------------  ---------------  ---------------  ---------------  ---------------  ---------------
matmul                     10614.158us          0.000us                1      10614.158us          0.000us
mm                         10587.491us          0.000us                1      10587.491us          0.000us
_mm                        10585.851us          0.000us                1      10585.851us          0.000us
add_                        9396.106us          0.000us                1       9396.106us          0.000us
matmul                      7469.952us          0.000us                1       7469.952us          0.000us
mm                          7444.106us          0.000us                1       7444.106us          0.000us
_mm                         7442.054us          0.000us                1       7442.054us          0.000us
MmBackward                  7168.823us          0.000us                1       7168.823us          0.000us
matmul                      7065.029us          0.000us                1       7065.029us          0.000us
stack                       7030.157us          0.000us                1       7030.157us          0.000us
threshold                   6931.695us          0.000us                1       6931.695us          0.000us
threshold_forward           6924.311us          0.000us                1       6924.311us          0.000us
stack                       6900.516us          0.000us                1       6900.516us          0.000us
stack                       6820.106us          0.000us                1       6820.106us          0.000us
mm                          6572.720us          0.000us                1       6572.720us          0.000us
```
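The cProfile output shows roughly 750,000 calls to `linear` (and as many `matmul` and `t` calls) in just 3 iterations, which suggests per-call overhead on batch-1 tensors is a large part of the cost. Where several episodes can be stepped in lockstep (which the question notes is hard with this multi-model controller), stacking their inputs along the batch dimension replaces many tiny matrix multiplications with one larger one. A minimal sketch with a single hypothetical `Linear` layer and random per-episode rows:

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(16, 16)
# Hypothetical per-episode inputs, each with batch dimension 1.
rows = [torch.randn(1, 16) for _ in range(100)]

# Unbatched: one small matmul per episode (100 separate layer calls).
out_loop = torch.cat([layer(r) for r in rows], dim=0)

# Batched: concatenate once, then a single layer call covers all 100 rows.
out_batched = layer(torch.cat(rows, dim=0))

print(out_loop.shape, torch.allclose(out_loop, out_batched, atol=1e-6))
```

Both versions produce the same values up to floating-point rounding; the batched one issues one operator call instead of 100, and its autograd graph has correspondingly fewer nodes to traverse in the backward pass.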
1. Why don't you directly sum the losses instead of stacking them into a tensor?
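```python
loss_sum = 0

for i in range(100):
    result_state = play_episode()
    loss = mse_loss(result_state, target_state)
    loss_sum += loss  # accumulate the loss tensor for this episode

optimizer.zero_grad()
loss_sum.backward()  # gradient of the sum is 100x the gradient of the mean
optimizer.step()
```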
2. We need more information about your code (maybe the computation graph could be simplified somewhere to lighten the backward pass).
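Summing instead of stacking changes only how the scalar loss is assembled: every episode's graph is still kept alive until the single backward call, so little speed difference should be expected. Dividing the running sum by a count recovers the same gradient as the stacked mean. A minimal check, assuming a hypothetical `Linear` model and random input/target pairs in place of real episodes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(3, 3)  # hypothetical stand-in for the controller
pairs = [(torch.randn(1, 3), torch.randn(1, 3)) for _ in range(10)]

# Mean via torch.stack, as in the question.
model.zero_grad()
torch.stack([F.mse_loss(model(x), y) for x, y in pairs]).mean().backward()
grad_mean = model.weight.grad.clone()

# Running sum plus a count, divided once at the end.
model.zero_grad()
loss_sum, count = 0, 0
for x, y in pairs:
    loss_sum = loss_sum + F.mse_loss(model(x), y)  # int + tensor yields a tensor
    count += 1
(loss_sum / count).backward()
grad_sum = model.weight.grad.clone()

print(torch.allclose(grad_mean, grad_sum, atol=1e-6))
```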

@Driesssens, did you ever find a solution to your problem? I have much the same issue…

Not really. I tried a few other things, like summing the losses directly (and keeping track of the number of summed terms so I could take the average at the end), but none was clearly faster than the others. In the end, it turned out that much smaller batches (20-50 episodes) sped up learning significantly for my problem, which alleviated the issue.