I want to train a meta network whose output is the weights of another network. Suppose generated_weights denotes the weights produced by the meta network; I tried the following code to inject those weights into the target network:

for name, param in backbone.named_parameters():
    param.data.copy_(generated_weights[name][0])

However, when I train with this and print out the weights, they don't change at all, which means backprop or the optimizer isn't doing anything. I suspect that .copy_ on param.data doesn't preserve the computational graph. I have no idea how to accomplish this properly; any help would be really appreciated.
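For what it's worth, I imported torch.func.functional_call in train.py because I suspect it might be the relevant tool here. The sketch below (not tested) is roughly what I have in mind: instead of copying into param.data, pass the generated weights directly into the forward call so the graph reaches back into the meta network.

# Sketch, not tested. model, generated_weights, xt, t are the same names
# as in the full train.py below.
params = {name: generated_weights[name][0] for name, _ in model.named_parameters()}
pred = functional_call(model, params, (xt, t))  # forward pass using the generated weights

Is something along these lines the right way to do it, or is there a better pattern?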
train.py:
import tqdm
import math
import torch
import numpy as np
from torch import nn
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from model.backbone import Backbone
from model.metanet import MetaNet
from torch.func import functional_call


def gen_data(task='A'):
    N = 1000
    x_min, x_max = -4, 4
    y_min, y_max = -4, 4
    resolution = 100  # Resolution of the grid
    # Create the grid
    x = np.linspace(x_min, x_max, resolution)
    y = np.linspace(y_min, y_max, resolution)
    X, Y = np.meshgrid(x, y)
    # Checkerboard pattern
    length = 4
    if task == 'A':
        checkerboard = np.indices((length, length)).sum(axis=0) % 2
    else:
        checkerboard = np.indices((length, length)).sum(axis=0) % 3
    # Sample points in regions where checkerboard pattern is 1
    sampled_points = []
    while len(sampled_points) < N:
        # Randomly sample a point within the x and y range
        x_sample = np.random.uniform(x_min, x_max)
        y_sample = np.random.uniform(y_min, y_max)
        # Determine the closest grid index
        i = int((x_sample - x_min) / (x_max - x_min) * length)
        j = int((y_sample - y_min) / (y_max - y_min) * length)
        # Check if the sampled point is in a region where checkerboard == 1
        if checkerboard[j, i] == 1:
            sampled_points.append((x_sample, y_sample))
    return sampled_points


model = Backbone(layers=5, channels=512)
metanet = MetaNet(model)
optim = torch.optim.AdamW(list(metanet.parameters()) + list(model.parameters()), lr=1e-4)
criterion = nn.MSELoss()

device = torch.device("cpu")
model.to(device)
metanet.to(device)

num_epochs = 1000
for epoch in tqdm.tqdm(range(num_epochs)):
    optim.zero_grad()
    task = np.random.choice(['A', 'B'])
    if task == 'A':
        task_info = torch.tensor([[1.0, 0.0]], device=device)
    else:
        task_info = torch.tensor([[0.0, 1.0]], device=device)
    data_points = np.array(gen_data(task), dtype=np.float32)
    train_x = torch.tensor(data_points, device=device)
    # target = train_x.clone()
    generated_weights = metanet(task_info)
    for name, param in model.named_parameters():
        print(generated_weights[name][0])
        param.data.copy_(generated_weights[name][0])
    training_steps = 100
    batch_size = 64
    pbar = tqdm.tqdm(range(training_steps))
    for i in pbar:
        x1 = train_x[torch.randint(train_x.size(0), (batch_size,))]
        x0 = torch.randn_like(x1)
        target = x1 - x0
        t = torch.rand(x1.size(0))
        xt = (1 - t[:, None]) * x0 + t[:, None] * x1
        pred = model(xt, t)  # also add t here
        loss = ((target - pred)**2).mean()
        loss.backward()
        optim.step()
        optim.zero_grad()
        pbar.set_postfix(loss=loss.item())
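I think the following check illustrates the problem (sketch, run right after the weight-injection loop above): the generated weights carry a grad_fn, but the backbone parameters stay plain leaves with no connection to the meta network.

# Rough check after the copy loop in train.py (sketch):
w = metanet(task_info)
print(w['in_projection.weight'].grad_fn)    # has a grad_fn, e.g. <ViewBackward0>
print(model.in_projection.weight.grad_fn)   # None: .data.copy_ bypasses autograd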
backbone.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Block(nn.Module):
    def __init__(self, channels=512):
        super().__init__()
        self.ff = nn.Linear(channels, channels)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.ff(x))


class Backbone(nn.Module):
    def __init__(self, channels_data=2, layers=5, channels=512,
                 channels_t=512):
        super().__init__()
        self.channels_t = channels_t
        self.in_projection = nn.Linear(channels_data, channels)
        self.t_projection = nn.Linear(channels_t, channels)
        self.blocks = nn.Sequential(*[
            Block(channels) for _ in range(layers)
        ])
        self.out_projection = nn.Linear(channels, channels_data)

    def gen_t_embedding(self, t, max_positions=10000):
        t = t * max_positions
        half_dim = self.channels_t // 2
        emb = math.log(max_positions) / (half_dim - 1)
        emb = torch.arange(half_dim, device=t.device).float().mul(-emb).exp()
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=1)
        if self.channels_t % 2 == 1:
            emb = F.pad(emb, (0, 1), mode='constant')
        return emb

    def forward(self, x, t):
        x = self.in_projection(x)
        t = self.gen_t_embedding(t)
        t = self.t_projection(t)
        x = x + t
        x = self.blocks(x)
        x = self.out_projection(x)
        return x
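For reference, this is how the backbone gets called, with the shapes I'm using in train.py (values here are just illustrative):

# Shape check for Backbone (illustrative):
import torch
from model.backbone import Backbone

backbone = Backbone(layers=5, channels=512)
xt = torch.randn(64, 2)   # a batch of 2-D data points
t = torch.rand(64)        # time values in [0, 1]
out = backbone(xt, t)     # -> torch.Size([64, 2])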
metanet.py:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from model.backbone import Backbone


class ResidualBlock(nn.Module):
    """
    A simple residual block that adds the input to the output.
    """
    def __init__(self, input_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, input_dim)
        self.act = nn.ReLU()

    def forward(self, x):
        return x + self.act(self.fc(x))


class MetaNet(nn.Module):
    """
    Meta network for generating target weights for the Backbone network,
    given the input task info, which is a one-hot vector.
    """
    def __init__(self, backbone: nn.Module, task_info_dim=2, hidden_dim=128):
        """
        Args:
            backbone: An instance of Backbone whose weights will be generated.
            task_info_dim: Dimensionality of the one-hot task info vector.
            hidden_dim: Dimensionality of the hidden layer in the meta-network.
        """
        super().__init__()
        self.backbone = backbone
        total_params = sum(p.numel() for p in backbone.parameters())
        self.residual = ResidualBlock(task_info_dim)
        self.meta_fc1 = nn.Linear(task_info_dim, hidden_dim)
        self.meta_act = nn.ReLU()
        self.meta_fc2 = nn.Linear(hidden_dim, total_params)

    def forward(self, task_info):
        """
        Args:
            task_info: A tensor of shape (batch_size, task_info_dim)
                representing one-hot encoded task information.
        Returns:
            A dictionary mapping each parameter name (as in backbone.named_parameters())
            to a tensor of generated weights with shape (batch_size, *parameter.shape).
        """
        # Generate a flattened vector of all weights
        x = self.residual(task_info)
        x = self.meta_act(x)
        x = self.meta_fc1(x)
        x = self.meta_act(x)
        x = self.meta_fc2(x)  # shape: (batch_size, total_params)

        generated_params = {}
        start = 0
        # Split and reshape the flattened vector to match each Backbone parameter.
        for name, param in self.backbone.named_parameters():
            num_params = param.numel()
            # Slice out the piece for this parameter.
            param_vector = x[:, start:start+num_params]
            start += num_params
            # Reshape so that each task in the batch gets a parameter tensor
            generated_params[name] = param_vector.view(-1, *param.shape)
        return generated_params
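And just to show what the meta network returns: a dict keyed by the backbone's parameter names, each entry with a leading batch dimension (shapes only, for illustration):

# MetaNet output shapes (illustration):
import torch
from model.backbone import Backbone
from model.metanet import MetaNet

backbone = Backbone(layers=5, channels=512)
metanet = MetaNet(backbone)
task_info = torch.tensor([[1.0, 0.0]])          # one-hot task vector, batch of 1
weights = metanet(task_info)
print(weights['in_projection.weight'].shape)    # torch.Size([1, 512, 2])
print(weights['blocks.0.ff.bias'].shape)        # torch.Size([1, 512])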