Hi guys, I am trying to make a hypernetwork and got stuck.
z0 = torch.Tensor([1.,1.])
input_size = z0.shape[0]
mainNetwork = torch.nn.Sequential(
torch.nn.Linear(input_size,64),
torch.nn.Tanh(),
torch.nn.Linear(64,input_size))
p_shape,_,theta_init = get_parameters(mainNetwork)
input_size_theta = theta_init.shape[0]
hyperNetwork = torch.nn.Sequential(
torch.nn.Linear(input_size_theta,64),
torch.nn.Tanh(),
torch.nn.Linear(64,input_size_theta))
z0 = torch.tensor([1.,1.],requires_grad = True)
t = 0
theta_0 = hyperNetwork(theta_init)
set_parameters(mainNetwork,theta_0)
z1 = mainNetwork(z0)
torch.autograd.grad(z1,hyperNetwork.parameters(),grad_outputs = torch.ones_like(z1),allow_unused = True)
I get (None,None,None,None) as output.
get_parameters is a function that returns all the flattened parameters of a model into a single tensor.
set_parameters takes a single tensor and sets it as the model weights.
def get_parameters(model):
p_shape =
flat_parameters =
theta = torch.empty(0)
for p in model.parameters():
p_shape.append(p.size())
flat_parameters.append(p.flatten())
theta = torch.cat([theta,p.flatten()],dim = 0)
return p_shape,flat_parameters,theta
def set_parameters(model,theta):
p_shape,flat_parameters,_ = get_parameters(model)
idx = 0
j = 0
for i in range(len(model)):
if isinstance(model[i],torch.nn.Linear):
sub_theta_weight = theta[idx : idx + np.prod(p_shape[j])].reshape(p_shape[j])
sub_theta_bias = theta[idx + np.prod(p_shape[j]):idx + np.prod(p_shape[j]) + np.prod(p_shape[j+1])].reshape(p_shape[j+1])
model[i].weight = torch.nn.Parameter(sub_theta_weight)
model[i].bias = torch.nn.Parameter(sub_theta_bias)
idx = np.prod(p_shape[j]) + np.prod(p_shape[j+1])
j += 2
I also went through Hypernetwork implementation - #8 by mariaalfaroc
Apparently, even in these examples, torch.autograd.grad(x,hyperNetwork.parameters()) is giving the exact same error as in my case even though the loss.backward() is working fine.
Any help will be greatly appreciated. I am new to autograd and trying to learn more about how it works.
Thanks.