[ERROR] Problem in updating batched variable

Hey all,
I am working in RL, using the A2C algorithm to train an agent to follow instructions. I spawn multiple environments simultaneously and batch the environment outputs to improve efficiency. I am attaching the code for model training here.

But the above code is giving this error, and the model is defined here.

I figured out that the problem is due to the update of particular hx and cx entries. The hx and cx are outputs of model(inputs) and depend on the scene's/agent's history. Whenever an episode starts in the i-th environment, we have to zero hx[i] and cx[i], and this in-place reset causes the problem.
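One common way to avoid mutating the saved hidden states is to reset them out of place: instead of writing zeros into hx[i]/cx[i], build new tensors by multiplying with a per-environment "done" mask. This is a minimal sketch with assumed shapes and names (num_envs, hidden_size, done) rather than your actual training code:

```python
import torch

# Assumed setup: one hidden/cell row per environment
num_envs, hidden_size = 4, 8
hx = torch.randn(num_envs, hidden_size, requires_grad=True)
cx = torch.randn(num_envs, hidden_size, requires_grad=True)

# done[i] == 1.0 means environment i just started a new episode
done = torch.tensor([0.0, 1.0, 0.0, 1.0])
mask = (1.0 - done).unsqueeze(1)  # shape (num_envs, 1), broadcasts over hidden dim

# Out-of-place reset: hx and cx themselves are never modified, so any
# graph that saved them for backward stays valid
hx_next = hx * mask
cx_next = cx * mask

# Rows for finished episodes are zero in the NEW tensors only
assert torch.all(hx_next[1] == 0) and torch.all(cx_next[3] == 0)
```

Feeding hx_next/cx_next into the next model step keeps the whole rollout differentiable without any in-place writes.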

I would like to know how to solve this. Thanks, and looking forward to your input. Please let me know if you need more clarification.

This context manager can be used to work around that error: Automatic differentiation package - torch.autograd — PyTorch 2.2 documentation. Let me know if that works for you.
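The linked context manager is presumably torch.autograd.graph.allow_mutation_on_saved_tensors, which lazily clones tensors saved for backward when they are mutated in place, so the in-place reset no longer invalidates the graph. A minimal sketch (not the thread's actual training loop) of how it behaves:

```python
import torch

hx = torch.randn(4, 8, requires_grad=True)

with torch.autograd.graph.allow_mutation_on_saved_tensors():
    hx_work = hx.clone()
    out = (hx_work ** 2).sum()  # hx_work is saved for backward here
    hx_work[0].zero_()          # in-place reset; the saved copy is cloned first
    out.backward()              # works: backward sees the pre-mutation values

# Gradient reflects the original (un-zeroed) values of hx_work
assert torch.allclose(hx.grad, 2 * hx)
```

Note the forward, the mutation, and the backward all happen inside the context; the trade-off is the extra memory for the defensive clones, so the out-of-place masked reset is usually cheaper if your code allows it.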

Thanks for the reply @soulitzer , will try it and report back.