Hypernetwork implementation

Hello everyone,
How can I implement this Keras hypernetwork example in PyTorch? I’m having problems with the backward pass.

If you could post a code snippet that reproduces the issue, it would be easier for the community to help you.

Here is a toy example. I want hyper_net to predict the weights for main_net. The backward pass should only compute gradients for the hypernetwork’s parameters.

import torch
import torch.nn as nn
import numpy as np

# Main ("target") network: 64 -> 8 -> 8 -> 1. Its weights are what the
# hypernetwork is supposed to predict.
main_layers = [
    nn.Linear(64, 8),
    nn.Linear(8, 8),
    nn.Linear(8, 1),
]
main_net = nn.Sequential(*main_layers)

def get_num_weights(model):
    """Record the shapes of ``model``'s trainable parameters and freeze them.

    Args:
        model: an ``nn.Module`` whose parameters will be predicted externally.

    Returns:
        ``(shapes, total)``. ``shapes`` maps each trainable parameter's name
        to its shape tuple, in ``named_parameters()`` order. ``total`` is the
        summed element count as a Python ``int``.

    Side effect:
        Sets ``requires_grad = False`` on every parameter it records. In this
        example that is why ``loss.backward()`` fails later: the main
        network's output then has no autograd history.
    """
    weights = {}
    num_weights = 0
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # already frozen: neither recorded nor counted
        weights[name] = tuple(param.shape)
        # numel() is always an int. torch.prod(torch.tensor(shape)).item()
        # yields the float 1.0 for 0-dim parameters, which would make the
        # total a float and break nn.Linear(..., total).
        num_weights += param.numel()
        param.requires_grad = False
    return weights, num_weights
# Record main_net's parameter shapes (this also freezes main_net), then size
# the hypernetwork's output layer to emit exactly that many values.
w, nw = get_num_weights(main_net)
hyper_layers = [
    nn.Linear(64, 8),
    nn.Linear(8, 8),
    nn.Linear(8, nw),
]
hyper_net = nn.Sequential(*hyper_layers)

def set_weights(weights_dict, model, pred_weights):
    """Copy consecutive slices of ``pred_weights`` into ``model``'s parameters.

    Args:
        weights_dict: parameter name -> shape tuple, as returned by
            ``get_num_weights``. Parameters not listed are left untouched
            and consume no predicted values.
        model: module whose listed parameters are overwritten in place, in
            ``model.named_parameters()`` order.
        pred_weights: flat predictions, sliced as
            ``pred_weights[:, start:stop]``.

    NOTE: values are written through ``param.data``. That copies numbers into
    the model but is not recorded by autograd, so no gradient flows back to
    whatever produced ``pred_weights`` (the hypernetwork here).
    NOTE(review): ``reshape`` only succeeds when ``pred_weights`` has a
    single row (batch size 1).
    """
    index = 0
    for name, param in model.named_parameters():
        if name not in weights_dict:
            continue
        shape = weights_dict[name]
        # torch.Size(...).numel() is an int even for 0-dim shapes, so the
        # slice bounds stay valid integers (prod(...).item() gives 1.0 there).
        count = torch.Size(shape).numel()
        param.data = pred_weights[:, index:index + count].reshape(shape)
        index += count

# Only the hypernetwork's parameters are given to the optimizer.
optimizer = torch.optim.RMSprop(hyper_net.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# One random (input, target) pair for the toy example.
x = torch.rand(1, 64)
y = torch.rand(1, 1)


# Single training step.
optimizer.zero_grad()
weights = hyper_net(x)  # shape (1, nw): flat predicted main_net weights
print(weights.shape)
set_weights(weights_dict=w, model=main_net, pred_weights=weights)
yhat = main_net(x)
loss = criterion(yhat, y)
# Raises "element 0 of tensors does not require grad and does not have a
# grad_fn": get_num_weights() froze every main_net parameter, so yhat/loss
# carry no autograd history. Even with requires_grad left on, set_weights()
# writes through .data, so hyper_net would still receive no gradient.
loss.backward()
optimizer.step()

I get the following error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

If I don’t set requires_grad to False, it works fine. However, I don’t know whether it is doing what I want. The main network should not take part in the gradient computation; I only use it to make predictions from a set of predicted weights. It’s basically a frozen model. In Keras, they use

# Keras snippet quoted for comparison (not PyTorch): marks every layer of the
# main network as already built. NOTE(review): presumably this stops Keras
# from creating the layers' own weights; confirm against the Keras example.
for layer in main_network.layers:
    layer.built = True

However, I don’t know whether PyTorch has an equivalent of the built attribute.

You are right! Sorry, it’s my first time posting here. Also, please forgive my English!

There is no need to turn off gradients for the main model, since the optimizer only updates the hypernetwork’s parameters. If you turn them off, your prediction and loss will have no gradient history either, so loss.backward() fails.

Here is the second issue, and the main caveat: you can’t simply assign the predictions of one model as the weight parameters of another model. Well, you can, but that detaches model 1 (the hypernetwork) from the computation graph. Consider the following:

import torch
from torch import nn, optim

A = nn.Linear(1, 1) # main network
B = nn.Linear(1, 2) # hypernetwork <-- need to optimize this

def foo(a, b, c):
    """Forward hook for B (module, inputs, output): log that B ran forward."""
    print("B Forward")

def bar(a, b, c):
    """Full backward hook for B (module, grad_input, grad_output): log the call.

    In this demo it never fires, because B's output is cut out of the graph
    that ``loss`` is built from.
    """
    print("B Backward")

B.register_forward_hook(foo)
B.register_full_backward_hook(bar)

optimizer = optim.Adam(B.parameters(), lr=1e-3)

X = torch.randn(1)

optimizer.zero_grad()

# B predicts A's weight and bias.
W, b = B(X)
# Writing through ``.data`` copies the values into A's parameters but does
# not connect them to B's output in the autograd graph.
A.weight.data = W.reshape(1, 1)
# The predicted bias belongs to A (the main network); assigning it to
# B.bias would overwrite the hypernetwork's own bias with a 0-dim tensor.
A.bias.data = b.reshape(1)

pred = A(X)
loss = pred.sum()
loss.backward()  # gradients reach A's parameters only; B gets none
optimizer.step()

If we run this we will get:

B Forward

“B Backward” isn’t printing. B isn’t getting any gradients!

The workaround is to use the functional API in torch.nn.functional instead of modules. Rather than nn.Linear(...)(x), call F.linear(x, W, b). That way the predicted weights stay in the computation graph and gradients reach the hypernetwork.

It is pretty simple to convert the toy Keras example you linked with only a few changes:

import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
import tensorflow as tf

def main_network(x, params):
    """Evaluate a two-layer MLP whose weights are supplied by the caller.

    ``params[0]`` and ``params[1]`` are ``(weight, bias)`` pairs for the
    hidden and output layers. Because ``F.linear`` uses the given tensors
    directly, gradients flow back into whatever produced them (here, the
    hypernetwork).
    """
    (hidden_w, hidden_b), (out_w, out_b) = params[0], params[1]
    hidden = F.relu(F.linear(x, hidden_w, hidden_b))
    return F.linear(hidden, out_w, out_b)

input_dim = 784
classes = 10

# (weight_shape, bias_shape) per main-network layer, in the order the
# hypernetwork's flat output vector is sliced.
layer_shapes = [
    [(64, input_dim), (64,)],   # (w0, b0)
    [(classes, 64), (classes,)] # (w1, b1)
]

# Number of values the hypernetwork must emit. It is derived from
# layer_shapes rather than hard-coded, so the two cannot drift apart when a
# layer changes.
num_weights_to_generate = int(
    sum(np.prod(shape) for layer in layer_shapes for shape in layer)
)

# Hypernetwork: input sample -> one flat vector holding every main-network
# weight and bias. The final Sigmoid squashes each generated value into (0, 1).
hypernetwork = nn.Sequential(
    nn.Linear(input_dim, 16),
    nn.ReLU(),
    nn.Linear(16, num_weights_to_generate),
    nn.Sigmoid()
)

loss_fn = nn.CrossEntropyLoss()
# Only the hypernetwork's parameters are handed to the optimizer.
optimizer = optim.Adam(hypernetwork.parameters(), lr=1e-4)

def train_step(x, y):
    """Run one hypernetwork update on a single batch and return the loss.

    The hypernetwork maps ``x`` to one flat row of main-network values,
    which is cut, in order, into the ``(weight, bias)`` shapes listed in
    ``layer_shapes``. ``main_network`` is evaluated with those tensors, so
    the loss backpropagates into the hypernetwork.

    Returns:
        The loss as a Python float.

    NOTE(review): each ``reshape(shape)`` only succeeds when ``x`` holds a
    single sample (batch size 1).
    """
    optimizer.zero_grad()
    flat = hypernetwork(x)

    cursor = 0

    def take(shape):
        # Slice the next prod(shape) values off the prediction and reshape them.
        nonlocal cursor
        size = np.prod(shape)
        piece = flat[:, cursor:cursor + size].reshape(shape)
        cursor += size
        return piece

    params = [[take(shape) for shape in layer] for layer in layer_shapes]

    loss = loss_fn(main_network(x, params), y)
    loss.backward()
    optimizer.step()
    return loss.item()

You can test it with the Keras MNIST dataset too; just convert the TensorFlow tensors to torch tensors via NumPy:

(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# Flatten each image to a 784-long float32 vector, divided by 255
# (presumably 0-255 pixel values -> [0, 1]).
dataset = tf.data.Dataset.from_tensor_slices(
    (x_train.reshape(60000, 784).astype("float32") / 255, y_train)
)
# One sample per batch: train_step's reshapes assume a batch size of 1.
dataset = dataset.shuffle(buffer_size=1024).batch(1)

losses = []


def torchify(a):
    """Convert an eager TensorFlow tensor to a torch tensor via NumPy."""
    return torch.from_numpy(a.numpy())


# Train on the first 1001 batches, printing the running mean loss each step.
for step, (x, y) in enumerate(dataset):
    x = torchify(x).float()
    y = torchify(y).long()
    loss = train_step(x, y)
    losses.append(loss)
    print("Step:", step, "Loss:", sum(losses) / len(losses))
    if step >= 1000:
        break

Hope this example helps!

2 Likes

This example is amazing! You absolutely solved my problem.

Sorry for the late reply! Work and life got in the way.
I have just replicated your answer and also tried it on a bigger model, and everything works perfectly! I cannot thank you enough :slight_smile:

1 Like

Hello!
I have a question: what if I want a hypernetwork that predicts the weights of an LSTM layer? I’ve been looking for a functional version of LSTM, but there isn’t one. Should I use this one?

# NOTE(review): torch._VF is a private, undocumented module whose signatures
# may change between releases. The call below is shown without its
# arguments (presumably input, hidden state and weights); it is a sketch,
# not runnable as written.
from torch import _VF
_VF.lstm()

@ID56 I’ve been thinking about this again. I was wondering whether using torch.nn.functional is the only way to implement hypernetworks. Is there a way to make it work using torch.nn.Parameter? Here is what I tried, based on your example:

import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
import tensorflow as tf

input_dim = 784
classes = 10

# (weight_shape, bias_shape) per main-network layer, in the order the
# hypernetwork's flat output vector is sliced.
layer_shapes = [
    [(64, input_dim), (64,)],   # (w0, b0)
    [(classes, 64), (classes,)] # (w1, b1)
]

# Number of values the hypernetwork must emit. It is derived from
# layer_shapes rather than hard-coded, so the two cannot drift apart when a
# layer changes.
num_weights_to_generate = int(
    sum(np.prod(shape) for layer in layer_shapes for shape in layer)
)

# Hypernetwork: input sample -> one flat vector holding every main-network
# weight and bias. The final Sigmoid squashes each generated value into (0, 1).
hypernetwork = nn.Sequential(
    nn.Linear(input_dim, 16),
    nn.ReLU(),
    nn.Linear(16, num_weights_to_generate),
    nn.Sigmoid()
)

loss_fn = nn.CrossEntropyLoss()
# Only the hypernetwork's parameters are handed to the optimizer.
optimizer = optim.Adam(hypernetwork.parameters(), lr=1e-4)

class MainNet(torch.nn.Module):
    """Linear(input_dim, 64) -> ReLU -> Linear(64, classes) with externally set weights.

    NOTE(review): as written this cannot train a hypernetwork. ``forward``
    runs entirely under ``torch.no_grad()``, so its output has no autograd
    history and ``loss.backward()`` raises. ``set_weights`` also re-wraps
    every predicted tensor in a fresh ``nn.Parameter``, which is a new leaf
    tensor. Gradients stop there and never reach the hypernetwork, even if
    the ``no_grad`` contexts are removed.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, 64)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(64, classes)

    @torch.no_grad()
    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

    @torch.no_grad()
    def set_weights(self, params):
        """Replace both layers' weight/bias with new, frozen Parameters.

        ``params`` is ``[[w0, b0], [w1, b1]]`` for linear1 and linear2.
        """
        (w0, b0), (w1, b1) = params[0], params[1]
        for layer, weight, bias in ((self.linear1, w0, b0), (self.linear2, w1, b1)):
            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
            layer.bias = torch.nn.Parameter(bias, requires_grad=False)

# Single MainNet instance; train_step overwrites its weights via set_weights() on every call.
main_network = MainNet()

def train_step(x, y):
    """One hypernetwork update using MainNet.set_weights (parameter-copy approach).

    Cuts the hypernetwork's flat output into the shapes in ``layer_shapes``,
    loads them into ``main_network``, backpropagates the loss and steps the
    optimizer.

    Returns:
        The loss as a Python float.

    NOTE(review): ``retain_graph=True`` looks unnecessary, because the graph
    is rebuilt on every call. Each ``reshape(shape)`` also assumes ``x``
    holds a single sample.
    """
    optimizer.zero_grad()
    flat = hypernetwork(x)

    cursor = 0

    def take(shape):
        # Slice the next prod(shape) values off the prediction and reshape them.
        nonlocal cursor
        size = np.prod(shape)
        piece = flat[:, cursor:cursor + size].reshape(shape)
        cursor += size
        return piece

    params = [[take(shape) for shape in layer] for layer in layer_shapes]

    main_network.set_weights(params)
    loss = loss_fn(main_network(x), y)
    loss.backward(retain_graph=True)
    optimizer.step()
    return loss.item()

(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# Flatten each image to a 784-long float32 vector, divided by 255
# (presumably 0-255 pixel values -> [0, 1]).
dataset = tf.data.Dataset.from_tensor_slices(
    (x_train.reshape(60000, 784).astype("float32") / 255, y_train)
)
# One sample per batch: train_step's reshapes assume a batch size of 1.
dataset = dataset.shuffle(buffer_size=1024).batch(1)

losses = []


def torchify(a):
    """Convert an eager TensorFlow tensor to a torch tensor via NumPy."""
    return torch.from_numpy(a.numpy())


# Train on the first 1001 batches, printing the running mean loss each step.
for step, (x, y) in enumerate(dataset):
    x = torchify(x).float()
    y = torchify(y).long()
    loss = train_step(x, y)
    losses.append(loss)
    print("Step:", step, "Loss:", sum(losses) / len(losses))
    if step >= 1000:
        break

If I remove all the torch.no_grad() calls and set requires_grad = True, it works. But I’m worried that I’m doing something wrong, because gradients for MainNet are also being computed. I think this is not a problem, since those gradients are never used: the weights are overwritten by the hypernetwork at every step. So the only cost would be that optimizing the hypernetwork takes longer.

I’m back again, haha! In PyTorch 1.12.0 there is torch.nn.utils.stateless.functional_call, which I recently discovered thanks to @albanD :grin: (see: Functional version of torch.nn.LSTM -> torch.nn.functional.LSTM · Issue #80454 · pytorch/pytorch · GitHub).

That way we can create the main network as an ordinary module but call it functionally. This is great for layers such as LSTM that have no functional counterpart. So the above example could be solved in the following way:

import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
import tensorflow as tf

input_dim = 784
classes = 10

# (weight_shape, bias_shape) per main-network layer, in the order the
# hypernetwork's flat output vector is sliced.
layer_shapes = [
    [(64, input_dim), (64,)],   # (w0, b0)
    [(classes, 64), (classes,)] # (w1, b1)
]

# Number of values the hypernetwork must emit. It is derived from
# layer_shapes rather than hard-coded, so the two cannot drift apart when a
# layer changes.
num_weights_to_generate = int(
    sum(np.prod(shape) for layer in layer_shapes for shape in layer)
)

# Hypernetwork: input sample -> one flat vector holding every main-network
# weight and bias. The final Sigmoid squashes each generated value into (0, 1).
hypernetwork = nn.Sequential(
    nn.Linear(input_dim, 16),
    nn.ReLU(),
    nn.Linear(16, num_weights_to_generate),
    nn.Sigmoid()
)

loss_fn = nn.CrossEntropyLoss()
# Only the hypernetwork's parameters are handed to the optimizer.
optimizer = optim.Adam(hypernetwork.parameters(), lr=1e-4)

class MainNet(torch.nn.Module):
    """Linear(input_dim, 64) -> ReLU -> Linear(64, classes).

    The parameter names ("linear1.weight", "linear1.bias", "linear2.weight",
    "linear2.bias") are the keys under which hypernetwork-predicted tensors
    are substituted when the module is called through ``functional_call``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, 64)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(64, classes)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

# Instantiate the template network and freeze all of its parameters;
# functional_call supplies the tensors actually used in each forward pass.
main_network = MainNet().requires_grad_(False)

def format_weights(params):
    """Map ``[[w0, b0], [w1, b1], ...]`` onto MainNet's parameter names.

    Only the first two layers are used. Each must unpack into exactly
    ``(weight, bias)``. The returned dict is keyed like
    ``MainNet.named_parameters()``, in linear1-then-linear2 order.
    """
    (w0, b0), (w1, b1) = params[0], params[1]
    return dict(zip(
        ("linear1.weight", "linear1.bias", "linear2.weight", "linear2.bias"),
        (w0, b0, w1, b1),
    ))

def train_step(x, y):
    """Run one optimization step of the hypernetwork and return the loss.

    The hypernetwork turns ``x`` into a flat vector of main-network
    parameters. That vector is cut into the shapes in ``layer_shapes`` and
    passed to ``main_network`` via ``functional_call``. Because the predicted
    tensors are used directly (not copied into ``nn.Parameter`` objects),
    gradients flow back into the hypernetwork.

    Returns:
        The scalar loss as a Python float.

    NOTE(review): each ``reshape(shape)`` only succeeds when ``x`` holds a
    single sample (batch size 1).
    """
    optimizer.zero_grad()
    weights_pred = hypernetwork(x)

    idx, params = 0, []
    for layer in layer_shapes:
        layer_params = []
        for shape in layer:
            offset = int(np.prod(shape))
            layer_params.append(weights_pred[:, idx: idx + offset].reshape(shape))
            idx += offset
        params.append(layer_params)

    weights = format_weights(params)
    # NOTE(review): newer PyTorch (2.x) offers torch.func.functional_call as
    # the preferred spelling of this API; confirm for your installed version.
    preds = torch.nn.utils.stateless.functional_call(main_network, weights, x)
    loss = loss_fn(preds, y)
    # The graph is rebuilt from scratch on every call and backpropagated only
    # once, so there is no reason to keep its buffers alive (retain_graph).
    loss.backward()
    optimizer.step()
    return loss.item()



(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
# Flatten each image to a 784-long float32 vector, divided by 255
# (presumably 0-255 pixel values -> [0, 1]).
dataset = tf.data.Dataset.from_tensor_slices(
    (x_train.reshape(60000, 784).astype("float32") / 255, y_train)
)
# One sample per batch: train_step's reshapes assume a batch size of 1.
dataset = dataset.shuffle(buffer_size=1024).batch(1)

losses = []


def torchify(a):
    """Convert an eager TensorFlow tensor to a torch tensor via NumPy."""
    return torch.from_numpy(a.numpy())


# Train on the first 1001 batches, printing the running mean loss each step.
for step, (x, y) in enumerate(dataset):
    x, y = torchify(x).float(), torchify(y).long()
    # train_step returns only the scalar loss (a float). Unpacking it into
    # (loss, preds) raises TypeError on the first step.
    loss = train_step(x, y)
    losses.append(loss)
    print("Step:", step, "Loss:", sum(losses) / len(losses))
    if step >= 1000:
        break

Hope this helps!

1 Like