@ID56 I’ve been thinking about this again. I was wondering if then, using torch.nn.Functional
is the only way to implement hypernetworks. Is there a way to make it work using torch.nn.Parameter
? If I try to do so using your example:
import torch
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
import tensorflow as tf
input_dim = 784
classes = 10
layer_shapes = [
[(64, input_dim), (64,)], # (w0, b0)
[(classes, 64), (classes,)] # (w1, b1)
]
num_weights_to_generate = (classes * 64 + classes) + (64 * input_dim + 64)
hypernetwork = nn.Sequential(
nn.Linear(input_dim, 16),
nn.ReLU(),
nn.Linear(16, num_weights_to_generate),
nn.Sigmoid()
)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(hypernetwork.parameters(), lr=1e-4)
class MainNet(torch.nn.Module):
def __init__(self):
super(MainNet, self).__init__()
self.linear1 = torch.nn.Linear(input_dim, 64)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(64, classes)
def forward(self, x):
with torch.no_grad():
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
def set_weights(self, params):
with torch.no_grad():
w0, b0 = params[0]
w1, b1 = params[1]
self.linear1.weight = torch.nn.Parameter(w0, requires_grad=False)
self.linear1.bias = torch.nn.Parameter(b0, requires_grad=False)
self.linear2.weight = torch.nn.Parameter(w1, requires_grad=False)
self.linear2.bias = torch.nn.Parameter(b1, requires_grad=False)
main_network = MainNet()
def train_step(x, y):
optimizer.zero_grad()
weights_pred = hypernetwork(x)
idx, params = 0, []
for layer in layer_shapes:
layer_params = []
for shape in layer:
offset = np.prod(shape)
layer_params.append(weights_pred[:, idx: idx + offset].reshape(shape))
idx += offset
params.append(layer_params)
main_network.set_weights(params)
preds = main_network(x)
loss = loss_fn(preds, y)
loss.backward(retain_graph=True)
optimizer.step()
return loss.item()
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
dataset = tf.data.Dataset.from_tensor_slices(
(x_train.reshape(60000, 784).astype("float32") / 255, y_train)
)
dataset = dataset.shuffle(buffer_size=1024).batch(1)
losses = []
torchify = lambda a: torch.from_numpy(a.numpy())
for step, (x, y) in enumerate(dataset):
x, y = torchify(x).float(), torchify(y).long()
loss = train_step(x, y)
losses.append(loss)
if step % 1 == 0: print("Step:", step, "Loss:", sum(losses) / len(losses))
if step >= 1000: break
If I remove all the torch.no_grad()
and set requires_grad = True
, it works. But I’m worried that I’m doing something wrong, because the gradients for MainNetwork are also being computed. Even though I think this is not a problem since those new weights that backprop computes are never used since they are overwrite by Hypernetwork. So, the only problem would be that it would take longer to optimize the Hypernetwork.