Problem with Hypernetwork in combination with TorchRL

I am currently writing my bachelor thesis, for which I am working on a Hypernetwork that should be combined with algorithms such as PPO and SAC. In doing so, I encountered a problem.

We have a TargetNetwork class and a Hypernetwork class, both based on nn.Module.
The target network is similar to an MLP, except that it has a set_weights method that takes a list of tensors or a ParameterList.
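Roughly, the two classes look like this (a simplified sketch, not the exact thesis code; the layer shapes, the embedding-based hypernet heads and the missing biases are just placeholders):

import torch
from torch import nn

class TargetNetwork(nn.Module):
    # an MLP whose weights are supplied from outside via set_weights
    def __init__(self, sizes):
        super().__init__()
        self.sizes = sizes
        self.weights = None  # filled by set_weights

    def set_weights(self, weights):
        # keep the tensors as-is (no nn.Parameter), so the graph to their producer survives
        self.weights = list(weights)

    def forward(self, x):
        for i, w in enumerate(self.weights):
            x = x @ w.t()
            if i < len(self.weights) - 1:
                x = torch.relu(x)
        return x

class Hypernetwork(nn.Module):
    # maps a task index to one weight tensor per target layer
    def __init__(self, num_tasks, shapes, emb_dim=8):
        super().__init__()
        self.emb = nn.Embedding(num_tasks, emb_dim)
        self.heads = nn.ModuleList(nn.Linear(emb_dim, out_f * in_f) for out_f, in_f in shapes)
        self.shapes = shapes

    def forward(self, task_idx):
        e = self.emb(torch.as_tensor(task_idx))
        return [head(e).reshape(shape) for head, shape in zip(self.heads, self.shapes)]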

1. First the collector takes a step in the env and returns the result as a tensordict. This is then stored in a replay buffer and sampled from later.

2. In a forward pass, the Hypernetwork generates the weights for the TargetNetwork:

actor_weights = hnet_actor.forward(task_idx)
tnet_actor.operator.set_weights(actor_weights)
3. Then the loss module, for example PPOLoss, calculates the loss and backpropagates the gradient:
subdata = replay_buffer.sample()  
loss_vals = loss_module(subdata.to(device))  
loss_value = (loss_vals["loss_objective"]  + loss_vals["loss_critic"]  + loss_vals["loss_entropy"])  
loss_value.backward()

Since the computation graph is carried along with the generated weights, the gradient should also reach the Hypernetwork.
4. The optimizer then updates the Hypernetwork.
5. The next iteration then repeats these steps.
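Putting the steps together, the intended loop looks roughly like this (a sketch; collector, num_epochs, hnet_optimizer and task_idx stand for my collector, the number of update epochs, the optimizer over the Hypernetwork parameters and the current task index):

for tensordict_data in collector:                      # step 1: collect env data
    replay_buffer.extend(tensordict_data.reshape(-1))
    for _ in range(num_epochs):
        actor_weights = hnet_actor(task_idx)           # step 2: generate weights
        tnet_actor.operator.set_weights(actor_weights)

        subdata = replay_buffer.sample()               # step 3: compute the loss
        loss_vals = loss_module(subdata.to(device))
        loss_value = (loss_vals["loss_objective"]
                      + loss_vals["loss_critic"]
                      + loss_vals["loss_entropy"])

        hnet_optimizer.zero_grad()
        loss_value.backward()                          # gradients flow back to the Hypernetwork
        hnet_optimizer.step()                          # step 4: update the Hypernetwork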

Now to the problem: as far as I understand the code of the different loss modules and the collector, the actor and critic are converted into a functional variant. Therefore simply exchanging the weights no longer works, because the loss module and the collector do not pick up the change in their functional variant.

Is there a way to set the weights such that they keep the graph of the Hypernetwork?

Kind regards, Fabian

Thanks for reporting this.
What I understand is that your actor/critic is stateless, am I right?
In that case, I think you should pass your parameters within the input tensordict.

subdata = replay_buffer.sample()  
actor_weights = hnet_actor.forward(task_idx)
curr_data = TensorDict({"subdata": subdata, "actor_weights": actor_weights}, batch_size=[])

Now we have packed the weights and data in a single tensordict.
I’m assuming that actor_weights is a dictionary here.
Then you need an actor that does something like:

from tensordict.nn import TensorDictModuleBase

class Actor(TensorDictModuleBase):
    in_keys = [...]
    out_keys = [...]
    def forward(self, tensordict):
        # unpack the weights and the data that were bundled above
        actor_weights = tensordict.get("actor_weights")
        subdata = tensordict.get("subdata")
        # load the hypernetwork-generated weights into the target network
        self.tnet_actor.operator.set_weights(actor_weights)
        tensordict = self.tnet_actor(subdata)
        return tensordict
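Calling it would then look something like this (a sketch, with your tnet_actor attached to the module by hand):

actor_module = Actor()
actor_module.tnet_actor = tnet_actor   # attach your target-network actor
out = actor_module(curr_data)          # applies the weights, then runs subdata through the actor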

It seems a bit convoluted in this case, but you’ll get to pass your parameters and data in one go to the model.
Would that solve the issue?

Do not hesitate to reach out if this is not clear or if you need more info!

What I understand is that your actor/critic is stateless, am I right?
If we only talk about the actor: it is not stateless. Only in the loss module (PPO) is it converted into a functional variant.

The idea (TensorDict) is actually nice, but other problems arise. The biggest problem is that the actor has parameters, and there is no official way to exchange these parameters with external tensors while keeping the graph intact.

I have found a way that is not pretty but works.
After each forward of the Hypernet I exchange the parameters in the LossModule (e.g. PPO) and in the actor model with a NamedMemberAccessor (only for the actor), and no longer with set_weights. The graph is retained when the parameters are set this way, and on backward the gradient then also reaches the Hypernetwork. Why both? If I only exchange the parameters in the actor module, the loss module still holds the old ones, because we exchange the parameter objects and not the data inside them. That is why I have to swap the tensors in the loss module as well (a rough sketch follows below the parameter names).

I exchange these parameters in the loss module, for example:

actor_sep_module_sep_0_sep_module_sep_operator_sep_weights_sep_0                                                                                                                                                                                                                             
actor_sep_module_sep_0_sep_module_sep_operator_sep_weights_sep_1
actor_sep_module_sep_0_sep_module_sep_operator_sep_weights_sep_2
actor_sep_module_sep_0_sep_module_sep_operator_sep_weights_sep_3
actor_sep_module_sep_0_sep_module_sep_operator_sep_weights_sep_4
actor_sep_module_sep_0_sep_module_sep_operator_sep_weights_sep_5
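A minimal sketch of what I do, assuming the target network stores its weights in a ParameterList called weights and using NamedMemberAccessor from the private torch.nn.utils._named_member_accessor module (the loss-module key names are the _sep_ names listed above):

from torch.nn.utils._named_member_accessor import NamedMemberAccessor

actor_weights = hnet_actor(task_idx)

# 1) swap the tensors inside the actor module itself
actor_accessor = NamedMemberAccessor(tnet_actor.operator)
for i, w in enumerate(actor_weights):
    # a plain tensor replaces the registered parameter, so the autograd
    # graph back to the Hypernetwork is preserved
    actor_accessor.swap_tensor(f"weights.{i}", w)

# 2) swap the matching (flattened, _sep_-named) copies held by the loss module
loss_accessor = NamedMemberAccessor(loss_module)
for i, w in enumerate(actor_weights):
    loss_accessor.swap_tensor(
        f"actor_sep_module_sep_0_sep_module_sep_operator_sep_weights_sep_{i}", w
    )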

With your solution there is also the problem that the collector itself calls the actor network internally, and as far as I have seen there is no way to add the weights to the tensordict there.

I think I get it

Would this work?

from tensordict.nn import TensorDictModule as Mod, NormalParamExtractor
from torchrl.modules import ProbabilisticActor, TanhNormal
from torchrl.objectives import ClipPPOLoss as PPO
from torchrl.objectives.value import GAE
import torch
from torch import nn
from tensordict import TensorDict

module = Mod(nn.Sequential(nn.Linear(3, 4), NormalParamExtractor()), in_keys=["obs"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(module, in_keys=["loc", "scale"], distribution_class=TanhNormal)
critic = Mod(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
gae = GAE(gamma=0.99, lmbda=0.95, value_network=critic)
ppo = PPO(actor, critic)

from tensordict.nn import make_functional

params = make_functional(ppo)

data = TensorDict({
    "obs": torch.randn(3),
    "action": torch.randn(2),
    "sample_log_prob": torch.randn(()),
    ("next", "done"): torch.zeros(1, dtype=torch.bool),
    ("next", "reward"): torch.zeros(1),
    ("obs"): torch.randn(3),
}, []).expand(10).contiguous()

gae(data)

ppo(data.view(-1), params=params)

I’m making the entire PPO loss functional, stripping out all the parameters and buffers.
That way you have full control over what they do.

Note that the params that are displayed are a bit messy: you’ll have a critic and an actor entry that are completely empty, and some strange-looking params with _sep_ in the middle, as this is how LossModule stores the params.
There is definitely room for improvement on our side to make functional calls to LossModules easier 🙂

But at least with this you have access to the params of the actor and critic and can play with them.
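For your use case, the idea would then be to overwrite the actor entries in params with the output of the hypernetwork before calling the loss (a sketch; hnet_actor and task_idx are your objects, and the _sep_ key names are whatever params actually shows for your module):

actor_weights = hnet_actor(task_idx)
for i, w in enumerate(actor_weights):
    # replace the extracted actor parameters with the hypernet-generated tensors
    params.set(f"actor_sep_module_sep_0_sep_module_sep_operator_sep_weights_sep_{i}", w)

loss_vals = ppo(data.view(-1), params=params)
loss = loss_vals["loss_objective"] + loss_vals["loss_critic"] + loss_vals["loss_entropy"]
loss.backward()  # gradients now flow into the Hypernetwork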

Would that solve your issue?

I’ll take a look at the suggestion; I’m busy with something else right now, but I’ll get back to you.

What I noticed quickly, however, is a problem with the collector: the collector does not accept any parameters (params) for the rollout.