Error in Shared Layers

Hi there.

I am trying to implement two simple models whose last fully connected layer is shared. I found the post "Sharing Layers Discussion" and tried to run the code provided by its original poster, but it gave an error.

So I tried to implement it myself with the code below. (The first layer of each model is exclusive to that model, and the last layer must be shared.)

import torch
from torch import nn, optim
import torch.nn.functional as F


class A(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_fc_a = nn.Linear(3, 3)  # first layer, exclusive to A
        self.fc = nn.Linear(3, 1)         # last layer, shared with B

    def forward(self, x):
        x2 = self.base_fc_a(x)
        x3 = self.fc(x2)
        return x3


class B(nn.Module):
    def __init__(self, shared_re):
        super().__init__()
        self.base_fc_b = nn.Linear(3, 3)  # first layer, exclusive to B
        self.shared_fc = shared_re        # the same module object as net_A.fc

    def forward(self, x):
        x2 = self.base_fc_b(x)
        x3 = self.shared_fc(x2)
        return x3


net_A = A()
net_B = B(shared_re=net_A.fc)
optim_A = optim.Adam(net_A.parameters())
optim_B = optim.Adam(net_B.parameters())
target = torch.randn(1, 1)

x_A = torch.rand(1, 3)
y_A_hat = net_A(x_A)
loss_A = F.mse_loss(y_A_hat, target)

x_B = torch.rand(1, 3)
y_B_hat = net_B(x_B)
loss_B = F.mse_loss(y_B_hat, target)



Well, this gives an error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

However, when I compute loss_B after calling loss_A.backward(), there is no issue. What causes this? Any help or explanation is appreciated.
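For reference, here is a minimal reproduction. It assumes the failing order was loss_A.backward(), then optim_A.step(), then loss_B.backward(); the post does not show those calls, so that order is an assumption. The nn.Sequential models are a compact stand-in for classes A and B (separate first layers, one shared last layer), and Tensor._version is an internal PyTorch attribute used here only to make the in-place version bump visible.

```python
import torch
from torch import nn, optim
import torch.nn.functional as F

# Compact stand-in for classes A and B: separate first layers, one shared last layer
shared_fc = nn.Linear(3, 1)
net_A = nn.Sequential(nn.Linear(3, 3), shared_fc)
net_B = nn.Sequential(nn.Linear(3, 3), shared_fc)
optim_A = optim.Adam(net_A.parameters())
optim_B = optim.Adam(net_B.parameters())
target = torch.randn(1, 1)

# Both forward passes run first, so both graphs save the shared weight as it is now
loss_A = F.mse_loss(net_A(torch.rand(1, 3)), target)
loss_B = F.mse_loss(net_B(torch.rand(1, 3)), target)

version_before = shared_fc.weight._version
loss_A.backward()
optim_A.step()  # updates shared_fc.weight in place, bumping its version counter
version_after = shared_fc.weight._version
print(version_after > version_before)  # True

raised = False
try:
    loss_B.backward()  # loss_B's graph still expects the pre-step version
except RuntimeError as err:
    raised = True
    print("inplace" in str(err))  # True
```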

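.step() updates the model weights in place. During the forward pass, the graph for loss_B saved the shared weight (the [3, 1] tensor in the message is most likely fc.weight transposed, as saved by the linear layer). When optim_A.step() modifies that weight before loss_B.backward() runs, autograd sees that the saved tensor's version no longer matches and raises the error.
You can either call .step() only after the second .backward(), or redo the forward pass with the updated weights after the first .step().
Hope this helps.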
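Both suggestions can be sketched as follows. This again uses an nn.Sequential stand-in for classes A and B, and make_models is a hypothetical helper added only so each option starts from fresh models. The optim_B.zero_grad() call in option 2 is an addition, not part of the reply above; it clears the gradient that loss_A left on the shared layer.

```python
import torch
from torch import nn, optim
import torch.nn.functional as F


def make_models():
    # Hypothetical helper: stand-in for classes A and B sharing their last layer
    shared_fc = nn.Linear(3, 1)
    net_A = nn.Sequential(nn.Linear(3, 3), shared_fc)
    net_B = nn.Sequential(nn.Linear(3, 3), shared_fc)
    return net_A, net_B


target = torch.randn(1, 1)
x_A, x_B = torch.rand(1, 3), torch.rand(1, 3)

# Option 1: finish every backward pass before any optimizer step
net_A, net_B = make_models()
optim_A = optim.Adam(net_A.parameters())
optim_B = optim.Adam(net_B.parameters())
loss_A = F.mse_loss(net_A(x_A), target)
loss_B = F.mse_loss(net_B(x_B), target)
loss_A.backward()
loss_B.backward()  # shared weight not modified yet, so this succeeds
# The shared layer's .grad now holds the sum of A's and B's gradients
optim_A.step()
optim_B.step()

# Option 2: step after the first backward, then redo the forward pass for B
net_A, net_B = make_models()
optim_A = optim.Adam(net_A.parameters())
optim_B = optim.Adam(net_B.parameters())
loss_A = F.mse_loss(net_A(x_A), target)
loss_A.backward()
optim_A.step()
optim_B.zero_grad()  # drop the gradient loss_A left on the shared layer
loss_B = F.mse_loss(net_B(x_B), target)  # new graph saves the updated weight
loss_B.backward()
optim_B.step()
```

One thing to keep in mind with option 1: because both optimizers hold the shared parameters, the shared layer is updated twice per iteration, each time with the summed gradient. If a single update is intended, one alternative (not from this thread) is a single optimizer over the deduplicated union of both models' parameters.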