I think what you are trying to say is that, because I want the parameters to “record” history, the example I talked about above uses register_buffer instead of nn.Parameter as a neat hack.
I think it’s starting to make sense now. Here’s what I think is going on (code: Daniel’s code for ‘Learning To Reweight’ algorithm):
def train_lre():
    net, opt = build_model()  # uses MetaModule to create the model instead of nn.Module
    meta_losses_clean = []
    net_losses = []
    plot_step = 100
    smoothing_alpha = 0.9
    meta_l = 0
    net_l = 0
    accuracy_log = []
    for i in tqdm(range(hyperparameters['num_iterations'])):
        net.train()
        # Line 2 get batch of data
        image, labels = next(iter(data_loader))
        # since validation data is small I just fixed them instead of building an iterator
        # initialize a dummy network for the meta learning of the weights
        meta_net = LeNet(n_out=1)
        meta_net.load_state_dict(net.state_dict())
        if torch.cuda.is_available():
            meta_net.cuda()
        image = to_var(image, requires_grad=False)
        labels = to_var(labels, requires_grad=False)
        # Lines 4 - 5 initial forward pass to compute the initial weighted loss
        y_f_hat = meta_net(image)
        cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduce=False)
        eps = to_var(torch.zeros(cost.size()))
        l_f_meta = torch.sum(cost * eps)
        meta_net.zero_grad()
        # Line 6 perform a parameter update
        grads = torch.autograd.grad(l_f_meta, (meta_net.params()), create_graph=True)
        meta_net.update_params(hyperparameters['lr'], source_params=grads)
        # Line 8 - 10 2nd forward pass and getting the gradients with respect to epsilon
        y_g_hat = meta_net(val_data)
        l_g_meta = F.binary_cross_entropy_with_logits(y_g_hat, val_labels)
        grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]
        # Line 11 computing and normalizing the weights
        w_tilde = torch.clamp(-grad_eps, min=0)
        norm_c = torch.sum(w_tilde)
        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde
        # Lines 12 - 14 computing for the loss with the computed weights
        # and then perform a parameter update
        y_f_hat = net(image)
        cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduce=False)
        l_f = torch.sum(cost * w)
        opt.zero_grad()
        l_f.backward()
        opt.step()
    # note: acc_log is not defined in this excerpt (the evaluation/logging code is left out)
    return np.mean(acc_log[-6:-1, 1])
The trick here is to use buffers (created via register_buffer in MetaLinear, etc.) to hold trainable tensors (weight and bias) built from the nn.Parameter values, and to work with those tensors (instead of nn.Parameters, which don’t record any history) through the named_leaves, named_params, update_params, and set_param functions to do the meta-learning. The nn.Parameters are updated via opt.step(), whereas all the intermediate parameter updates required for meta-learning are handled via the buffer variables (viz. weight and bias, created via register_buffer).
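To convince myself why the intermediate update has to keep its history, here is a minimal hand-worked sketch of Lines 4 - 11 without autograd. It assumes a hypothetical one-parameter model (prediction = theta * x) with a squared-error cost, and all function names and data below are made up for illustration; none of this is Daniel's code, which uses LeNet and binary cross-entropy. The point is that the updated theta is a function of eps, so the validation loss has a gradient with respect to eps only if that dependence is kept.

```python
# Toy version of the reweighting step, with gradients derived by hand.
# Assumptions (not from the original code): scalar model theta * x,
# per-example cost 0.5 * (theta * x - y) ** 2.

def train_grad(theta, x, y):
    """d/dtheta of the per-example training cost."""
    return (theta * x - y) * x

def val_loss(theta, val):
    """Mean validation cost at theta."""
    return sum(0.5 * (theta * x - y) ** 2 for x, y in val) / len(val)

def val_grad(theta, val):
    """d/dtheta of the mean validation cost."""
    return sum((theta * x - y) * x for x, y in val) / len(val)

def updated_theta(theta, eps, train, lr):
    """One SGD step on the eps-weighted training loss; the result depends on eps."""
    step = sum(e * train_grad(theta, x, y) for e, (x, y) in zip(eps, train))
    return theta - lr * step

def grad_eps(theta, eps, train, val, lr):
    """Chain rule through the update:
    dL_val/deps_i = dL_val/dtheta_new * (-lr * dC_i/dtheta)."""
    theta_new = updated_theta(theta, eps, train, lr)
    g_val = val_grad(theta_new, val)
    return [g_val * (-lr) * train_grad(theta, x, y) for x, y in train]

def example_weights(g_eps):
    """Clamp the negated gradients at zero and normalize, as in Line 11."""
    w_tilde = [max(-g, 0.0) for g in g_eps]
    norm = sum(w_tilde)
    return [w / norm for w in w_tilde] if norm != 0 else w_tilde

train = [(1.0, 2.0), (2.0, 4.0), (1.0, -3.0)]  # third label disagrees with y = 2x
val = [(1.0, 2.0), (3.0, 6.0)]                 # clean data following y = 2x
theta, lr = 0.0, 0.1
eps = [0.0, 0.0, 0.0]

g = grad_eps(theta, eps, train, val, lr)
weights = example_weights(g)
print(weights)
```

In this toy setup the third training example, whose label disagrees with the clean validation data, ends up with weight zero. As far as I can tell, autograd evaluates the same chain rule in the real code, which is why update_params has to produce new tensors that are still connected to eps rather than overwrite the values in place.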
@albanD Sorry for bothering you so much but I think that does explain it, don’t you think?