Suggestion for code optimisation

Hello everyone,

I’m building an application in which two agents exchange messages. Each agent is a neural network implemented in PyTorch, and the agents keep generating predictions until the end of the episode. A possible interaction is the following:

  1. Agent 1: <message 1> -> Agent 2: <answer to message 1>
  2. Agent 1: <message 2> -> Agent 2: <answer to message 2>
  3. Agent 1: <message 3> -> Agent 2: <answer to message 3>
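For illustration, one common way to shorten a loop like this is sketched below, under several assumptions. The whole multi-turn exchange is run for the batch under `torch.no_grad()`, so no autograd graph is built during generation and the frozen Agent 2 never needs one. Only what Agent 1 saw and emitted is stored. When optimising, Agent 1's log-probabilities for every (example, turn) pair are then recomputed in a single batched forward pass. The `Agent` class, the single-token messages, Agent 2's greedy answers and the reward are all hypothetical placeholders. The recomputation trick relies on Agent 1's input at each turn being fully determined by stored tensors. With a recurrent agent, you would also need to store or replay its hidden state.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

VOCAB, HIDDEN, N_TURNS = 16, 32, 3  # hypothetical sizes


class Agent(nn.Module):
    """Hypothetical agent: scores the next message from the last message received and a context vector."""

    def __init__(self, vocab=VOCAB, hidden=HIDDEN):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.head = nn.Linear(2 * hidden, vocab)

    def forward(self, prev_msg, context):
        # prev_msg: (N,) token ids, context: (N, hidden) -> logits: (N, vocab)
        return self.head(torch.cat([self.embed(prev_msg), context], dim=-1))


@torch.no_grad()  # no autograd graph is built while generating the episode
def rollout(agent1, agent2, context, n_turns=N_TURNS):
    """Play n_turns of Agent 1 -> Agent 2 exchanges for the whole batch in parallel."""
    batch = context.shape[0]
    answer = torch.zeros(batch, dtype=torch.long, device=context.device)  # hypothetical start token 0
    inputs, msgs, answers = [], [], []
    for _ in range(n_turns):
        inputs.append(answer)  # what Agent 1 saw at this turn
        msg = Categorical(logits=agent1(answer, context)).sample()
        answer = agent2(msg, context).argmax(dim=-1)  # greedy answer from the frozen agent
        msgs.append(msg)
        answers.append(answer)
    # each returned tensor has shape (batch, n_turns)
    return torch.stack(inputs, 1), torch.stack(msgs, 1), torch.stack(answers, 1)


def recompute_log_probs(agent1, inputs, msgs, context):
    """One batched forward pass over all (example, turn) pairs, this time with gradients."""
    batch, turns = inputs.shape
    ctx = context.unsqueeze(1).expand(-1, turns, -1).reshape(batch * turns, -1)
    logits = agent1(inputs.reshape(-1), ctx)
    return Categorical(logits=logits).log_prob(msgs.reshape(-1)).view(batch, turns)


# Usage: one optimisation step for Agent 1 only
torch.manual_seed(0)
agent1, agent2 = Agent(), Agent()
agent2.requires_grad_(False).eval()  # Agent 2 is never optimised
optimizer = torch.optim.Adam(agent1.parameters(), lr=1e-3)

context = torch.randn(8, HIDDEN)
inputs, msgs, answers = rollout(agent1, agent2, context)
log_probs = recompute_log_probs(agent1, inputs, msgs, context)
rewards = (answers == 1).float()  # hypothetical reward: Agent 2 answered token 1
loss = -(log_probs * rewards).mean()  # REINFORCE-style surrogate without a baseline
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

The design choice here is to pay for autograd only once per episode, over a `(batch * turns)` tensor, instead of keeping a graph alive across every turn. If the third agent is trained jointly, its parameters could be passed to the same optimizer.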

After this exchange, I collect the experiences and run an optimisation step for Agent 1 and for a possible third agent used in the process (Agent 2 is not optimised). I’m looking for ways to minimise the time spent in the phase where the two agents generate messages over multiple turns. In my current implementation, the models use batched operations on the GPU to process multiple examples at once, but some operations are still written in pure Python. How would you suggest I proceed in optimising this code? Do you have any optimisation suggestions specific to this design?
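As for the pure-Python parts, per-example bookkeeping inside the turn loop is a frequent bottleneck. Calls like `.tolist()` or `.item()` on a CUDA tensor copy data to the host and force the CPU to wait for the GPU. The sketch below uses a hypothetical example of such an operation: marking which turns are still "active" once an agent has produced a stop token. It shows a loop version and an equivalent version built only from tensor ops that stays on the device. The stop-token convention and both function names are assumptions for illustration. For the sample input, the vectorised mask rows are `[T, T, F, F]`, `[T, F, F, F]` and `[T, T, T, T]`.

```python
import torch


def active_mask_loop(answers, stop_token):
    """Pure-Python version: turn t is active until the first stop token has been produced."""
    mask = []
    for row in answers.tolist():  # .tolist() copies to the host (and syncs if on GPU)
        done, row_mask = False, []
        for tok in row:
            row_mask.append(not done)
            if tok == stop_token:
                done = True
        mask.append(row_mask)
    return torch.tensor(mask, dtype=torch.bool)


def active_mask_vectorised(answers, stop_token):
    """Same result with tensor ops only; stays on the device of `answers`."""
    is_stop = (answers == stop_token).long()
    # number of stop tokens produced strictly before each turn
    stops_before = torch.cumsum(is_stop, dim=1) - is_stop
    return stops_before == 0


answers = torch.tensor([[3, 0, 5, 0],
                        [0, 2, 2, 2],
                        [1, 1, 1, 1]])  # (batch, turns)
mask = active_mask_vectorised(answers, stop_token=0)
```

The same pattern (`cumsum`, `torch.where`, boolean masks) usually covers most per-example `if`/`for` logic, so the whole batch advances one turn per GPU call.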