Hypernetwork implementation

There is no need to turn gradients off for the main model, since the optimizer is built from the hypernetwork’s parameters only and will update nothing else. If you do turn gradients off, your prediction and loss will not require gradients either, and loss.backward() will fail.
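
As a rough illustration of that point, here is a minimal sketch showing that parameters not passed to the optimizer stay untouched even though they receive gradients. The two toy layers and the combined loss are assumptions made up for this check; they are not the setup from the question.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

A = nn.Linear(1, 1)  # stand-in for the main network, requires_grad left on
B = nn.Linear(1, 2)  # stand-in for the hypernetwork
optimizer = optim.Adam(B.parameters(), lr=1e-3)  # only B is registered

a_before = A.weight.detach().clone()
b_before = B.weight.detach().clone()

x = torch.randn(4, 1)
loss = A(x).sum() + B(x).sum()  # toy loss touching both modules
loss.backward()
optimizer.step()

a_unchanged = torch.equal(A.weight, a_before)    # A got a gradient but no update
b_changed = not torch.equal(B.weight, b_before)  # B was updated by Adam
print("A unchanged:", a_unchanged, "| B changed:", b_changed)
```

A still accumulates values in `.grad`, which costs a little memory, but the optimizer never reads them.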

Here is the second issue, and the main caveat: you can’t simply assign the predictions of one model as the weight parameters of another model. Well, you can, but doing so detaches model 1 (the hypernetwork) from the computation graph. Consider the following:

import torch
from torch import nn, optim

A = nn.Linear(1, 1) # main network
B = nn.Linear(1, 2) # hypernetwork <-- need to optimize this

def foo(a, b, c):
    print("B Forward")

def bar(a, b, c):
    print("B Backward")

B.register_forward_hook(foo)
B.register_full_backward_hook(bar)

optimizer = optim.Adam(B.parameters(), lr=1e-3)

X = torch.randn(1)

optimizer.zero_grad()

W, b = B(X)
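A.weight.data = W.reshape(1, 1)  # assigning through .data bypasses autograd
A.bias.data = b.reshape(1)       # the bias belongs to A (main network), not B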

pred = A(X)
loss = pred.sum()
loss.backward()
optimizer.step()

If we run this we will get:

B Forward

“B Backward” isn’t printing. B isn’t getting any gradients!
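
To see where the graph is cut, a small check along these lines can help. It is only a sketch that reuses the toy layers from above without the hooks, and the variable names are made up for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

A = nn.Linear(1, 1)  # main network
B = nn.Linear(1, 2)  # hypernetwork
X = torch.randn(1)

W, b = B(X)
w_attached = W.grad_fn is not None  # W still remembers that it came from B

A.weight.data = W.reshape(1, 1)     # copies the values only, not the history
a_weight_is_leaf = A.weight.grad_fn is None

A(X).sum().backward()
b_grad_missing = B.weight.grad is None  # nothing flowed back into B

print("W attached:", w_attached)
print("A.weight has no history:", a_weight_is_leaf)
print("B.weight.grad missing:", b_grad_missing)
```

After the `.data` assignment, `A.weight` is an ordinary leaf parameter again, so backpropagation stops there.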

The workaround is to use the functional layers from torch.nn.functional. Instead of nn.Linear()(x), call F.linear(x, W, b) with the generated tensors directly, so that they stay in the graph and the gradients are retained.
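
For the toy example, the functional version might look like the sketch below. The reshapes are assumptions to make the two scalars produced by B fit the shapes F.linear expects for a 1-in, 1-out layer.

```python
import torch
from torch import nn, optim
import torch.nn.functional as F

torch.manual_seed(0)

B = nn.Linear(1, 2)  # hypernetwork: emits one weight and one bias
optimizer = optim.Adam(B.parameters(), lr=1e-3)

X = torch.randn(1)

optimizer.zero_grad()
W, b = B(X)  # both tensors are still attached to B's graph

# The main "network" is now a plain function of the generated tensors.
pred = F.linear(X, W.reshape(1, 1), b.reshape(1))
loss = pred.sum()
loss.backward()

has_grad = B.weight.grad is not None and B.bias.grad is not None
print("B received gradients:", has_grad)
optimizer.step()
```

With the hooks from the first snippet registered on B, the backward hook should now fire as well.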

The toy Keras example you linked can be converted with only a few changes:

import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
import tensorflow as tf

def main_network(x, params):
    w0, b0 = params[0]
    w1, b1 = params[1]

    x = F.relu(F.linear(x, w0, b0))
    x = F.linear(x, w1, b1)
    return x

input_dim = 784
classes = 10

layer_shapes = [
    [(64, input_dim), (64,)],   # (w0, b0)
    [(classes, 64), (classes,)] # (w1, b1)
]

num_weights_to_generate = (classes * 64 + classes) + (64 * input_dim + 64)

hypernetwork = nn.Sequential(
    nn.Linear(input_dim, 16),
    nn.ReLU(),
    nn.Linear(16, num_weights_to_generate),
    nn.Sigmoid()
)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(hypernetwork.parameters(), lr=1e-4)

def train_step(x, y):
    optimizer.zero_grad()
    weights_pred = hypernetwork(x)
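
    # Slice the flat output into per-layer (weight, bias) tensors.
    # Note: this assumes a batch size of 1, so one row holds one full parameter set.
    idx, params = 0, []
    for layer in layer_shapes:
        layer_params = []
        for shape in layer:
            offset = int(np.prod(shape))  # number of values this tensor needs
            layer_params.append(weights_pred[:, idx: idx + offset].reshape(shape))
            idx += offset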
        params.append(layer_params)

    preds = main_network(x, params)
    loss = loss_fn(preds, y)
    loss.backward()
    optimizer.step()
    return loss.item()
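
The slicing loop inside train_step can also be pulled out into a standalone helper, which makes the shape bookkeeping easier to check in isolation. This is just a sketch: `unflatten_params` is a hypothetical name, and the random tensor stands in for the hypernetwork output at batch size 1.

```python
import math
import torch

def unflatten_params(flat, layer_shapes):
    """Split a 1-D tensor into [[w, b], ...] following layer_shapes."""
    expected = sum(math.prod(s) for layer in layer_shapes for s in layer)
    if flat.numel() != expected:
        raise ValueError(f"expected {expected} values, got {flat.numel()}")
    params, idx = [], 0
    for layer in layer_shapes:
        layer_params = []
        for shape in layer:
            n = math.prod(shape)
            layer_params.append(flat[idx:idx + n].reshape(shape))
            idx += n
        params.append(layer_params)
    return params

layer_shapes = [[(64, 784), (64,)], [(10, 64), (10,)]]
total = 64 * 784 + 64 + 10 * 64 + 10

# Stand-in for hypernetwork(x) with a single sample in the batch.
flat = torch.randn(1, total, requires_grad=True)
params = unflatten_params(flat.reshape(-1), layer_shapes)
shapes = [[tuple(p.shape) for p in layer] for layer in params]
print(shapes)
```

Because the slices are views of the hypernetwork output, gradients still flow back through them. The size check also makes the batch-size-1 limitation explicit: a larger batch produces more values than one parameter set needs, and the helper raises instead of failing later in reshape.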

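You can test it with the Keras MNIST dataset too; just convert the tf tensors to torch tensors via numpy. Keep the batch size at 1, because train_step reshapes the hypernetwork output into a single set of weights.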

(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices(
    (x_train.reshape(60000, 784).astype("float32") / 255, y_train)
)
dataset = dataset.shuffle(buffer_size=1024).batch(1)

losses = []

torchify = lambda a: torch.from_numpy(a.numpy())

for step, (x, y) in enumerate(dataset):
    x, y = torchify(x).float(), torchify(y).long()
    loss = train_step(x, y)
    losses.append(loss)
    if step % 1 == 0: print("Step:", step, "Loss:", sum(losses) / len(losses))
    if step >= 1000: break

Hope this example helps!
