Hypernetwork implementation

@ID56 I’ve been thinking about this again. I was wondering whether using torch.nn.functional is really the only way to implement hypernetworks, or whether there is a way to make it work with torch.nn.Parameter. This is my attempt, based on your example:

import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
import tensorflow as tf

input_dim = 784
classes = 10

layer_shapes = [
    [(64, input_dim), (64,)],   # (w0, b0)
    [(classes, 64), (classes,)] # (w1, b1)
]

num_weights_to_generate = (classes * 64 + classes) + (64 * input_dim + 64)

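# The hypernetwork maps each input to a flat vector holding every weight and bias of the main network.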
hypernetwork = nn.Sequential(
    nn.Linear(input_dim, 16),
    nn.ReLU(),
    nn.Linear(16, num_weights_to_generate),
    nn.Sigmoid()
)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(hypernetwork.parameters(), lr=1e-4)

class MainNet(torch.nn.Module):
    def __init__(self):
        super(MainNet, self).__init__()
        self.linear1 = torch.nn.Linear(input_dim, 64)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(64, classes)

    def forward(self, x):
        with torch.no_grad():
            x = self.linear1(x)
            x = self.relu(x)
            x = self.linear2(x)
            return x

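    # set_weights copies the tensors produced by the hypernetwork into this
    # network's layers, stored as non-trainable nn.Parameter objects.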
    def set_weights(self, params):
        with torch.no_grad():
            w0, b0 = params[0]
            w1, b1 = params[1]
            self.linear1.weight = torch.nn.Parameter(w0, requires_grad=False)
            self.linear1.bias = torch.nn.Parameter(b0, requires_grad=False)
            self.linear2.weight = torch.nn.Parameter(w1, requires_grad=False)
            self.linear2.bias = torch.nn.Parameter(b1, requires_grad=False)

main_network = MainNet()

def train_step(x, y):
    optimizer.zero_grad()
    weights_pred = hypernetwork(x)
 
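    # Slice the flat prediction back into the per-layer weight/bias shapes.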
    idx, params = 0, []
    for layer in layer_shapes:
        layer_params = []
        for shape in layer:
            offset = np.prod(shape)
            layer_params.append(weights_pred[:, idx: idx + offset].reshape(shape))
            idx += offset
        params.append(layer_params)

    main_network.set_weights(params)
    preds = main_network(x)
    loss = loss_fn(preds, y)
    loss.backward(retain_graph=True)
    optimizer.step()
    return loss.item()

(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices(
    (x_train.reshape(60000, 784).astype("float32") / 255, y_train)
)
dataset = dataset.shuffle(buffer_size=1024).batch(1)

losses = []

torchify = lambda a: torch.from_numpy(a.numpy())

for step, (x, y) in enumerate(dataset):
    x, y = torchify(x).float(), torchify(y).long()
    loss = train_step(x, y)
    losses.append(loss)
    if step % 1 == 0: print("Step:", step, "Loss:", sum(losses) / len(losses))
    if step >= 1000: break

If I remove all the torch.no_grad() blocks and set requires_grad=True, it works. But I’m worried that I’m doing something wrong, because gradients for MainNet are also being computed. I think this is not actually a problem, since whatever backprop computes for MainNet’s weights is never used: they are overwritten by the hypernetwork on the next step. So the only cost would be that optimizing the hypernetwork takes longer.
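
For reference, this is roughly how I picture the torch.nn.functional version working (just a sketch of my understanding, not your exact code; functional_forward is a name I made up): the generated tensors are fed straight to F.linear, so they stay on the hypernetwork’s computation graph and never need to be wrapped in nn.Parameter.

# Sketch: a forward pass that consumes the generated weights directly,
# so the loss can backpropagate into the hypernetwork.
def functional_forward(x, params):
    (w0, b0), (w1, b1) = params
    x = F.linear(x, w0, b0)   # plays the role of linear1
    x = F.relu(x)
    x = F.linear(x, w1, b1)   # plays the role of linear2
    return x

In train_step, preds = functional_forward(x, params) would then replace the set_weights(params) / main_network(x) pair.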