For an RL task, I want to do parameter updates over minibatches of, say, 100 or 500 episodes. All tensor operations in my code are unbatched (i.e. the batch dimension, if any, is 1), because the RL controller is actually composed of multiple separate models, which makes it hard to express the forward pass as one chain of batched tensor operations.
In pseudocode, per minibatch, I do:
```python
loss_terms = []
for i in range(100):
    result_state = play_episode()                  # forward pass through all sub-models
    loss = mse_loss(result_state, target_state)
    loss_terms.append(loss)
mean_loss = torch.mean(torch.stack(loss_terms))    # one graph spanning all 100 episodes
optimizer.zero_grad()
mean_loss.backward()
optimizer.step()
```
The backward pass is very slow. In other topics I read that stack is a slow operation, but indexing is allegedly slow too, so I'm not sure pre-allocating a minibatch-sized tensor and filling it by index would help much. Is the computational graph simply too large, and should I evaluate the gradient per episode instead? (Both alternatives are sketched below.) What's the best practice here?
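To make those alternatives concrete, here's an untested sketch of both, reusing the names from the pseudocode above:

```python
# (a) Pre-allocate a minibatch-sized tensor and fill it by index instead of list + stack.
losses = torch.empty(100)
for i in range(100):
    result_state = play_episode()
    losses[i] = mse_loss(result_state, target_state)  # in-place write, still differentiable
mean_loss = losses.mean()                             # then zero_grad/backward/step as before

# (b) Backward per episode: gradients accumulate in .grad across backward() calls,
# so one optimizer step at the end gives the same update as the mean-loss version,
# while each episode's graph is freed immediately instead of being kept alive
# for the whole minibatch.
optimizer.zero_grad()
for i in range(100):
    result_state = play_episode()
    loss = mse_loss(result_state, target_state) / 100  # scale so summed grads match the mean
    loss.backward()
optimizer.step()
```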
Here's the profiler output for 3 iterations with minibatch size 100; everything is on CPU. (Also: cProfile worked fine, but the autograd profiler completely flooded my RAM, to the extent that profiling with minibatch size 500 crashed. Does that make sense?) I collected both profiles roughly like this (a sketch; `run_minibatches` stands in for the training loop above):
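```python
import cProfile
import torch

# cProfile over 3 minibatch iterations, sorted by internal time
cProfile.run("run_minibatches(3)", sort="tottime")

# autograd profiler over the same workload; it records every tensor op it sees,
# which I assume is why memory blows up at larger minibatch sizes
with torch.autograd.profiler.profile() as prof:
    run_minibatches(3)
print(prof.table(sort_by="cpu_time_total"))
```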
```
--------------------------------------------------------------------------------
cProfile output
--------------------------------------------------------------------------------
23976946 function calls (22476653 primitive calls) in 80.041 seconds

Ordered by: internal time
List reduced from 2332 to 15 due to restriction <15>

       ncalls  tottime  percall  cumtime  percall  filename:lineno(function)
            6   26.195    4.366   26.195    4.366  {method 'run_backward' of 'torch._C._EngineBase' objects}
       744975   12.643    0.000   12.643    0.000  {method 'matmul' of 'torch._C._TensorBase' objects}
       750393    3.830    0.000   19.936    0.000  \torch\nn\functional.py:1010(linear)
1541421/48762    3.787    0.000   45.183    0.001  \torch\nn\modules\module.py:471(__call__)
       568890    3.180    0.000    3.180    0.000  {built-in method torch._C._nn.threshold}
       750393    3.140    0.000    3.140    0.000  {method 't' of 'torch._C._TensorBase' objects}
       455112    2.636    0.000    8.157    0.000  \ibp-pytorch\utilities.py:5(tensor_from)
       455112    1.974    0.000    1.974    0.000  {built-in method cat}
       744975    1.821    0.000   22.271    0.000  \torch\nn\modules\linear.py:54(forward)
        43344    1.732    0.000   36.132    0.001  \ibp-pytorch\imaginator.py:71(<listcomp>)
       176085    1.541    0.000   31.485    0.000  \torch\nn\modules\container.py:89(forward)
          301    1.340    0.004   27.767    0.092  \ibp-pytorch\imagination_based_planner.py:153(new_episode)
        43344    1.056    0.000   43.402    0.001  \ibp-pytorch\imaginator.py:70(forward)
       308826    1.047    0.000    1.047    0.000  {built-in method tensor}
       558054    0.968    0.000    0.968    0.000  {method 'float' of 'torch._C._TensorBase' objects}

--------------------------------------------------------------------------------
autograd profiler output (CPU mode)
--------------------------------------------------------------------------------
top 15 events sorted by cpu_time_total
--------------------- --------------- --------------- --------------- --------------- ---------------
Name                  CPU time        CUDA time       Calls           CPU total       CUDA total
--------------------- --------------- --------------- --------------- --------------- ---------------
matmul                10614.158us     0.000us         1               10614.158us     0.000us
mm                    10587.491us     0.000us         1               10587.491us     0.000us
_mm                   10585.851us     0.000us         1               10585.851us     0.000us
add_                  9396.106us      0.000us         1               9396.106us      0.000us
matmul                7469.952us      0.000us         1               7469.952us      0.000us
mm                    7444.106us      0.000us         1               7444.106us      0.000us
_mm                   7442.054us      0.000us         1               7442.054us      0.000us
MmBackward            7168.823us      0.000us         1               7168.823us      0.000us
matmul                7065.029us      0.000us         1               7065.029us      0.000us
stack                 7030.157us      0.000us         1               7030.157us      0.000us
threshold             6931.695us      0.000us         1               6931.695us      0.000us
threshold_forward     6924.311us      0.000us         1               6924.311us      0.000us
stack                 6900.516us      0.000us         1               6900.516us      0.000us
stack                 6820.106us      0.000us         1               6820.106us      0.000us
mm                    6572.720us      0.000us         1               6572.720us      0.000us
```