No need to turn off grad for the main model, since the optimizer will update only the hypernetwork's parameters. If you turn off its gradients, your prediction and loss will not require grad either (they will have no `grad_fn`), and `loss.backward()` will fail.
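Here is a minimal sketch of that failure. It assumes the gradients were turned off with `requires_grad_(False)` and that the predicted weights are copied into the main model through `.data`, using the same toy `A`/`B` pair as the next example:

```python
import torch
from torch import nn

A = nn.Linear(1, 1)      # main network
B = nn.Linear(1, 2)      # hypernetwork
A.requires_grad_(False)  # "turning grad off" for the main model

X = torch.randn(1)
W, b = B(X)
A.weight.data = W.reshape(1, 1)  # copy the predicted values into A
A.bias.data = b.reshape(1)

loss = A(X).sum()
print(loss.requires_grad)  # False: neither X nor A's parameters require grad

backward_failed = False
try:
    loss.backward()
except RuntimeError as err:
    # autograd refuses: the loss has no grad_fn to backpropagate through
    backward_failed = True
    print("backward failed:", err)
```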
Here is the second issue, and the main caveat: you can't simply assign the predictions of one model as the weight parameters of another model. Well, you can, but that detaches model 1 (the hypernetwork) from the computation graph. Consider the following:
```python
import torch
from torch import nn, optim

A = nn.Linear(1, 1)  # main network
B = nn.Linear(1, 2)  # hypernetwork <-- need to optimize this

def foo(module, inputs, output):
    print("B Forward")

def bar(module, grad_input, grad_output):
    print("B Backward")

B.register_forward_hook(foo)
B.register_full_backward_hook(bar)

optimizer = optim.Adam(B.parameters(), lr=1e-3)
X = torch.randn(1)

optimizer.zero_grad()
W, b = B(X)
A.weight.data = W.reshape(1, 1)  # copies the values but drops the link to B
A.bias.data = b.reshape(1)
pred = A(X)
loss = pred.sum()
loss.backward()
optimizer.step()
```
If we run this, we get:

```
B Forward
```
“B Backward” isn’t printing. B isn’t getting any gradients!
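If you'd rather not rely on hooks, you can check the `.grad` attributes directly after `backward()`. A sketch of the same setup (the `torch.manual_seed` call is only there for reproducibility):

```python
import torch
from torch import nn

torch.manual_seed(0)
A = nn.Linear(1, 1)  # main network
B = nn.Linear(1, 2)  # hypernetwork

X = torch.randn(1)
W, b = B(X)
A.weight.data = W.reshape(1, 1)  # values copied, autograd history dropped
A.bias.data = b.reshape(1)

loss = A(X).sum()
loss.backward()

print(W.grad_fn is not None)      # True: W itself was still produced by B
print(A.weight.grad is not None)  # True: A's own leaf parameters got gradients
print(B.weight.grad)              # None: nothing flowed back into the hypernetwork
```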
The workaround is to use the functional layers in `torch.nn.functional`. Instead of `nn.Linear(...)(x)`, use `F.linear(x, W, b)` with the predicted `W` and `b`, so that they stay in the computation graph and the gradients reach the hypernetwork.
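Applied to the toy example, the fix might look like the sketch below. The "main network" is reduced to a single `F.linear` call that takes its weight and bias straight from `B`'s output, so `backward()` now populates `B`'s gradients and `optimizer.step()` actually changes `B`:

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
B = nn.Linear(1, 2)  # hypernetwork: predicts one weight and one bias
optimizer = torch.optim.Adam(B.parameters(), lr=1e-3)

X = torch.randn(1)
optimizer.zero_grad()
W, b = B(X)  # both still attached to B's graph
pred = F.linear(X, W.reshape(1, 1), b.reshape(1))  # main network as a pure function
loss = pred.sum()
loss.backward()

print(B.weight.grad)  # a real tensor now, not None
before = B.weight.detach().clone()
optimizer.step()
print(torch.equal(before, B.weight.detach()))  # False: the hypernetwork was updated
```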
Converting the toy Keras example you linked takes only a few changes:
```python
import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
import tensorflow as tf  # only used below to load MNIST

def main_network(x, params):
    # The main network has no parameters of its own: every weight comes from `params`
    w0, b0 = params[0]
    w1, b1 = params[1]
    x = F.relu(F.linear(x, w0, b0))
    x = F.linear(x, w1, b1)
    return x

input_dim = 784
classes = 10

layer_shapes = [
    [(64, input_dim), (64,)],     # (w0, b0)
    [(classes, 64), (classes,)],  # (w1, b1)
]
num_weights_to_generate = (classes * 64 + classes) + (64 * input_dim + 64)

hypernetwork = nn.Sequential(
    nn.Linear(input_dim, 16),
    nn.ReLU(),
    nn.Linear(16, num_weights_to_generate),
    nn.Sigmoid(),
)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(hypernetwork.parameters(), lr=1e-4)

def train_step(x, y):
    optimizer.zero_grad()
    weights_pred = hypernetwork(x)  # shape (1, num_weights_to_generate); assumes batch size 1
    # Slice the flat prediction into the weight and bias tensors of each layer
    idx, params = 0, []
    for layer in layer_shapes:
        layer_params = []
        for shape in layer:
            offset = int(np.prod(shape))
            layer_params.append(weights_pred[:, idx: idx + offset].reshape(shape))
            idx += offset
        params.append(layer_params)
    preds = main_network(x, params)
    loss = loss_fn(preds, y)
    loss.backward()
    optimizer.step()
    return loss.item()
```
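`train_step` hard-codes `num_weights_to_generate` and slices by hand. If you change the architecture, it may be less error-prone to derive both from `layer_shapes`. A sketch using `torch.split`, where `unflatten_params` is a hypothetical helper (not part of PyTorch):

```python
import math
import torch

input_dim, classes = 784, 10
layer_shapes = [
    [(64, input_dim), (64,)],     # (w0, b0)
    [(classes, 64), (classes,)],  # (w1, b1)
]

# Derive the count from the shapes instead of writing it out by hand
sizes = [math.prod(shape) for layer in layer_shapes for shape in layer]
num_weights_to_generate = sum(sizes)  # 50176 + 64 + 640 + 10 = 50890

def unflatten_params(flat, layer_shapes):
    """Hypothetical helper: split a 1-D tensor into per-layer [W, b] tensors."""
    sizes = [math.prod(s) for layer in layer_shapes for s in layer]
    chunks = iter(torch.split(flat, sizes))
    return [[next(chunks).reshape(shape) for shape in layer] for layer in layer_shapes]

# Stand-in for one row of the hypernetwork's output
flat = torch.randn(num_weights_to_generate, requires_grad=True)
params = unflatten_params(flat, layer_shapes)
print([[tuple(p.shape) for p in layer] for layer in params])

# split and reshape are differentiable, so gradients still reach `flat`
total = sum(p.sum() for layer in params for p in layer)
total.backward()
print(torch.all(flat.grad == 1).item())  # True
```

Inside `train_step` you would call `unflatten_params(weights_pred[0], layer_shapes)`, which relies on the same batch-size-1 assumption as the manual slicing.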
You can test it with the Keras MNIST dataset too; just convert the tf tensors to torch tensors via NumPy:
```python
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices(
    (x_train.reshape(60000, 784).astype("float32") / 255, y_train)
)
dataset = dataset.shuffle(buffer_size=1024).batch(1)  # one set of generated weights per sample

losses = []
torchify = lambda a: torch.from_numpy(a.numpy())  # tf tensor -> NumPy -> torch tensor
for step, (x, y) in enumerate(dataset):
    x, y = torchify(x).float(), torchify(y).long()
    loss = train_step(x, y)
    losses.append(loss)
    if step % 100 == 0:
        print("Step:", step, "Loss:", sum(losses) / len(losses))
    if step >= 1000:
        break
```
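The loop feeds one sample at a time because `train_step` turns `weights_pred` into a single set of weights. If you want larger batches while still generating separate weights for every sample, one option (my own sketch, not something from the Keras example) is to keep the batch dimension on the generated weights and apply them with `torch.einsum`. `batched_main_network` is a hypothetical name, and the layer sizes are assumed to match `layer_shapes` above:

```python
import torch
from torch import nn
import torch.nn.functional as F

input_dim, hidden, classes = 784, 64, 10
sizes = [hidden * input_dim, hidden, classes * hidden, classes]  # w0, b0, w1, b1

hypernetwork = nn.Sequential(
    nn.Linear(input_dim, 16),
    nn.ReLU(),
    nn.Linear(16, sum(sizes)),
    nn.Sigmoid(),
)

def batched_main_network(x, flat):
    """Apply a different two-layer MLP to each sample; flat has shape (batch, sum(sizes))."""
    n = x.shape[0]
    w0, b0, w1, b1 = torch.split(flat, sizes, dim=1)
    w0 = w0.reshape(n, hidden, input_dim)
    w1 = w1.reshape(n, classes, hidden)
    # "bi,boi->bo" computes x[b] @ W[b].T separately for every sample b
    h = F.relu(torch.einsum("bi,boi->bo", x, w0) + b0)
    return torch.einsum("bi,boi->bo", h, w1) + b1

x = torch.rand(8, input_dim)
y = torch.randint(0, classes, (8,))
logits = batched_main_network(x, hypernetwork(x))
loss = F.cross_entropy(logits, y)
loss.backward()
print(logits.shape)  # torch.Size([8, 10])
```

Memory grows with the batch size, since every sample carries its own roughly 51k generated parameters.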
Hope this example helps!