Parameter defined by torch.cat does not get updated

I have two models, model_1 and model_2, which share a torch.nn.Parameter model_1.emb of shape [n,dim]. This parameter is defined in the init of model_1, and passed to the constructor of model_2 to define

self.shared_emb = emb

This shared parameter is used in the forward functions of both models, and it is properly updated by each of the two optimizers, which optimize a shared loss function:

L = L_model_1 + L_model_2

This works fine.
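For reference, here is a minimal sketch of the setup (the toy shapes, dummy forward functions, losses and SGD optimizers stand in for my actual code):

import torch
import torch.nn as nn

class Model1(nn.Module):
    def __init__(self, n=10, dim=4):
        super().__init__()
        self.emb = nn.Parameter(torch.randn(n, dim))   # the shared parameter

    def forward(self, idx):
        return self.emb[idx].sum()                     # stands in for L_model_1

class Model2(nn.Module):
    def __init__(self, emb):
        super().__init__()
        self.shared_emb = emb                          # same Parameter object as model_1.emb

    def forward(self, idx):
        return self.shared_emb[idx].mean()             # stands in for L_model_2

model_1 = Model1()
model_2 = Model2(model_1.emb)

opt_1 = torch.optim.SGD(model_1.parameters(), lr=0.1)
opt_2 = torch.optim.SGD(model_2.parameters(), lr=0.1)

idx = torch.tensor([0, 1])
loss = model_1(idx) + model_2(idx)   # L = L_model_1 + L_model_2
loss.backward()
opt_1.step()
opt_2.step()                         # shared_emb changes, as expected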

Now, in model_2, I want to concatenate self.shared_emb with a second parameter. To verify that this would work, I use a dummy parameter model_2.dum of shape [0,dim] and add the following code in the init of model_2:

self.cat_emb = torch.cat([self.shared_emb, self.dum], dim=0)

Now, model_2.shared_emb should be exactly the same thing as model_2.cat_emb. I replace the former with the latter in model_2’s forward function.
But for some reason, this new parameter does not get updated: its weights do not change!

I googled, and also tried the following:

self.cat_emb = torch.cat([emb, dum], dim=0)
self.cat_emb.retain_grad()

but still, the parameter defined this way does not get updated. requires_grad is set to True for both emb and dum. I have triple-checked everything and am at my wit’s end here.

Any ideas why this does not work?

self.cat_emb is created via an operation and is thus not a leaf variable anymore.
If you tried to pass self.cat_emb to the optimizer, it would raise an error, since it cannot optimize non-leaf tensors.
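You can check this directly; in the standalone snippet below, emb stands in for your shared parameter and dum for the zero-row dummy:

import torch
import torch.nn as nn

emb = nn.Parameter(torch.randn(10, 4))   # leaf parameter, like model_1.emb
dum = nn.Parameter(torch.randn(0, 4))    # leaf parameter, the zero-row dummy
cat_emb = torch.cat([emb, dum], dim=0)   # output of an op -> non-leaf tensor

print(emb.is_leaf)        # True  -> an optimizer can update it
print(cat_emb.is_leaf)    # False -> its grad_fn is a CatBackward node
# torch.optim.SGD([cat_emb], lr=0.1) would raise a ValueError complaining
# that it can't optimize a non-leaf Tensor.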

And what does that imply? Is there no way to make this work?

You won’t be able to update cat_emb itself: it is non-leaf, and since it is created once in the init it only holds a copy of the values from that moment. Instead, recreate it from shared_emb (which can be updated), e.g. inside the forward pass.
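Something along these lines should work (a sketch with placeholder shapes, indexing and loss): keep shared_emb and dum registered as parameters, and rebuild the concatenation inside forward, so gradients flow through the cat back into both leaf parameters, which the optimizer can then update.

import torch
import torch.nn as nn

class Model2(nn.Module):
    def __init__(self, emb, dim=4):
        super().__init__()
        self.shared_emb = emb                          # leaf Parameter shared with model_1
        self.dum = nn.Parameter(torch.randn(0, dim))   # second leaf Parameter

    def forward(self, idx):
        # Rebuild the concatenation on every forward pass. cat_emb is just an
        # intermediate tensor; gradients flow through it back into shared_emb
        # and dum, which are the tensors the optimizer actually updates.
        cat_emb = torch.cat([self.shared_emb, self.dum], dim=0)
        return cat_emb[idx].mean()

emb = nn.Parameter(torch.randn(10, 4))
model_2 = Model2(emb)
opt = torch.optim.SGD(model_2.parameters(), lr=0.1)    # sees shared_emb and dum

loss = model_2(torch.tensor([0, 1]))
loss.backward()
opt.step()   # shared_emb is updated; the next forward builds a fresh cat_emb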