Hi,
`.data` is in the process of being removed and should not be used. As you have experienced, it only does very confusing things.
You will need to have only `nn.Parameter`s be the true parameters, and you will have to recompute everything derived from them at each forward pass:
import torch
import torch.nn as nn
import torch.optim as optim
class testModule(nn.Module):
def __init__(self):
super(testModule, self).__init__()
self.fc1 = nn.Linear(5, 10, bias=True)
self.fc2 = nn.Linear(10, 10, bias=False)
# Remove the weights as we override them in the forward
# so that they don't show up when calling .parameters()
del self.fc1.weight
del self.fc2.weight
self.fc2_base_weights = nn.Parameter(torch.randn(10, 10))
self.shared_weights = nn.Parameter(torch.randn(10, 5))
def forward(self, x):
    """Apply fc1 then fc2, using weights rebuilt from the shared parameters.

    fc1's weight is ``shared_weights`` itself. fc2's weight is a copy of
    ``fc2_base_weights`` whose columns 1, 3, 5, 7, 9 are overwritten with
    ``shared_weights``. Gradients from both layers therefore reach
    ``shared_weights``, while the overwritten entries of
    ``fc2_base_weights`` receive zero gradient.

    The weights are kept as locals and passed to ``nn.functional.linear``
    (what ``nn.Linear.forward`` does internally) rather than being assigned
    onto ``self.fc1`` / ``self.fc2``. This matters for two reasons:

    * Assigning the ``nn.Parameter`` to ``fc1.weight`` re-registers it as a
      parameter of fc1, which adds a ``fc1.weight`` key to ``state_dict()``
      after the first call.
    * Storing the non-leaf clone on fc2 makes ``copy.deepcopy`` of the model
      fail after a forward pass.

    This method does not modify any module state.

    Args:
        x: input tensor whose last dimension is 5 (fc1's in_features).

    Returns:
        Tensor with the same leading dimensions as ``x`` and a last
        dimension of 10.
    """
    shared_cols = [1, 3, 5, 7, 9]
    # clone() yields an autograd-tracked copy, so the in-place column write
    # below leaves fc2_base_weights itself untouched.
    fc2_weight = self.fc2_base_weights.clone()
    fc2_weight[:, shared_cols] = self.shared_weights
    x = nn.functional.linear(x, self.shared_weights, self.fc1.bias)
    # fc2 was built with bias=False, so self.fc2.bias is None here.
    x = nn.functional.linear(x, fc2_weight, self.fc2.bias)
    return x
def _weight(self):