# Higher Order Derivatives - Meta Learning

I am trying to write code for some meta-learning algorithms. I understand that there are a few packages available for easy, hassle-free implementations of meta-learning algorithms (`higher`, `pytorch-meta`), but I want to understand a few things conceptually.

Recently, a few meta-learning algorithm implementations, such as Learning to Reweight and Meta-Weight-Net, have not been using `higher` or `pytorch-meta`; instead, they have been using a custom `nn.Module` (see code below) written by Daniel (code: https://github.com/danieltan07/learning-to-reweight-examples/blob/master/meta_layers.py). Basically, it’s the usual PyTorch code for supervised learning, with the only change being the use of this custom `nn.Module` subclass instead of the standard `nn.Module`.

I am pasting the relevant code below. The module shown below (Daniel’s code) is what people have been using for their meta-learning algorithms, thereby avoiding the additional packages I mentioned above.

```python
class MetaModule(nn.Module):
    # to_var(...) below is a small helper from the linked repo that wraps a
    # Tensor (moving it to GPU if available) with the given requires_grad flag
    def params(self):
        for name, param in self.named_params(self):
            yield param

    def named_leaves(self):
        return []

    def named_submodules(self):
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        if memo is None:
            memo = set()

        if hasattr(curr_module, 'named_leaves'):
            for name, p in curr_module.named_leaves():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p
        else:
            for name, p in curr_module._parameters.items():
                if p is not None and p not in memo:
                    memo.add(p)
                    yield prefix + ('.' if prefix else '') + name, p

        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p

    def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
        if source_params is not None:
            for tgt, src in zip(self.named_params(self), source_params):
                name_t, param_t = tgt
                grad = src
                if first_order:
                    grad = to_var(grad.detach().data)
                tmp = param_t - lr_inner * grad
                self.set_param(self, name_t, tmp)
        else:
            for name, param in self.named_params(self):
                if not detach:
                    grad = param.grad
                    if first_order:
                        grad = to_var(grad.detach().data)
                    tmp = param - lr_inner * grad
                    self.set_param(self, name, tmp)
                else:
                    param = param.detach_()
                    self.set_param(self, name, param)

    def set_param(self, curr_mod, name, param):
        if '.' in name:
            n = name.split('.')
            module_name = n[0]
            rest = '.'.join(n[1:])
            for name, mod in curr_mod.named_children():
                if module_name == name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)

    def detach_params(self):
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copy(self, other, same_var=False):
        for name, param in other.named_params():
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)
```

Using such a `MetaModule`, one can create `MetaLinear`, `MetaConv2d`, etc., which can be used in place of `nn.Linear`, `nn.Conv2d`:

```python
class MetaLinear(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Linear(*args, **kwargs)

        # register the weights as buffers (plain Tensors), not nn.Parameters
        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))
        self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]


class MetaConv2d(MetaModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.Conv2d(*args, **kwargs)

        self.stride = ignore.stride
        self.padding = ignore.padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups

        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))

        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]
```

I have the following questions:

• Why does one need to create a custom `nn.Module` and then use `setattr(name, param)` to update the parameters (as is done in the `update_params` function of the `MetaModule` class) so that these update operations are recorded in the computation graph? Why can’t I directly use `setattr(name, param)` in my regular training loop (i.e., in a `def train(*args, **kwargs)` function) with the standard, built-in `nn.Module`?

• I do understand that there’s another way to deal with this (link: [resolved] Implementing MAML in PyTorch). For instance, the `def forward(self, x)` function of an `nn.Module` can be changed to `def forward(self, x, weights)` so that the code works. But I am not fully clear about this either. I understand that `nn.Parameter`s don’t record history, so we need to operate on other Tensors and then copy those values into the `nn.Parameter`s, but I wonder if I can achieve what I want without resorting to this technique.

Note: My question is somewhat similar to Second order derivatives in meta-learning. However, what I am asking is conceptual and not necessarily a request for a workaround. And the only response in that thread is mine, so I am still not clear about implementing meta-learning algorithms in PyTorch.
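For concreteness, here is a minimal sketch of what I understand the `forward(x, weights)` technique to look like (the toy network and names are my own, not from the linked thread):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FunctionalLinearNet(nn.Module):
    """Toy net whose forward can optionally take externally supplied weights."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, x, weights=None):
        if weights is None:
            # standard pass using the registered nn.Parameters
            return self.fc(x)
        # meta pass: use external Tensors that carry autograd history
        return F.linear(x, weights['fc.weight'], weights['fc.bias'])

torch.manual_seed(0)
net = FunctionalLinearNet(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)

# inner-loop step: differentiate the loss w.r.t. the original parameters
loss = F.mse_loss(net(x), y)
grads = torch.autograd.grad(loss, net.parameters(), create_graph=True)
fast_weights = {name: p - 0.1 * g
                for (name, p), g in zip(net.named_parameters(), grads)}

# the outer loss backpropagates through the inner update to the originals
meta_loss = F.mse_loss(net(x, fast_weights), y)
meta_loss.backward()
print(net.fc.weight.grad is not None)  # True
```

If this is right, the `nn.Parameter`s are only read (never overwritten) during the inner loop, and the whole inner update lives in ordinary Tensors.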

Hi,

• I think the `set_param` function here is mainly built to handle nested names. For example, if your module is a `Sequential` that contains a conv, then the parameter name will be `0.weight`. You cannot use Python’s `setattr` with that name directly; you first need to access “0” and then “weight”.
• I don’t think you can get around some logic like that. The main reason is that you don’t want to override the original Parameters: you need to be able to backpropagate all the way back to them in order to update them, so the intermediary Tensors can’t just be these Parameters modified in place.
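The nested-name point can be seen directly with a quick sketch (toy `Sequential`):

```python
import torch.nn as nn

seq = nn.Sequential(nn.Conv2d(1, 4, 3))
print([n for n, _ in seq.named_parameters()])  # ['0.weight', '0.bias']

# setattr with a dotted name does NOT reach into the child module; it just
# creates an attribute literally named "0.weight" on the Sequential itself
setattr(seq, '0.weight', 'oops')
print(getattr(seq, '0.weight'))  # 'oops'
print(seq[0].weight.shape)       # the real Conv2d parameter is untouched
```

Hence `set_param` splits the name on `.` and recurses into the matching child before calling `setattr` on the last component.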

Hope this helps.

Thanks @albanD for responding so quickly. I understand your points and the part about handling nested names. But I want to understand why it is a problem when I use `setattr()` for each of those weights/biases separately, something like this:

```python
new_named_params = ...
for xxx in modules_of_network:
    for name, p in xxx.named_parameters():
        setattr(xxx, name, new_named_params[name])
```

I think I understand the problems with this approach to some degree, but I am not clear on how using `setattr()` in a custom `nn.Module` like `MetaModule` doesn’t throw the same kind of errors, even though the parameters undergo the same sort of in-place update.

`setattr` is actually the same as doing `xxx.<name> = new_named_params[name]`.
So if the attribute is already an `nn.Parameter`, you won’t be able to set a Tensor with history there.
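A minimal repro of what happens (toy example):

```python
import torch
import torch.nn as nn

lin = nn.Linear(2, 2)
updated = lin.weight - 0.1 * torch.ones_like(lin.weight)  # non-leaf, carries history

# nn.Module.__setattr__ refuses a plain Tensor where an nn.Parameter lives:
try:
    lin.weight = updated
except TypeError as e:
    print(type(e).__name__)  # TypeError

# Wrapping it back into nn.Parameter requires detaching first,
# which throws away exactly the history we wanted to keep:
lin.weight = nn.Parameter(updated.detach())
print(lin.weight.grad_fn)  # None -- a fresh leaf, no graph to backprop through
```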

I understand. But that still doesn’t explain how `MetaModule` enables Meta-Learning without using packages such as `higher` and others.

Hi,

It handles the params in a different way compared to the regular `nn.Module`. In particular, it allows the parameters to have history associated with them by not making them `nn.Parameter`s.

I think what you are trying to say is that, to let the parameters “record” history, the example I talked about above uses `register_buffer` instead of `nn.Parameter` as a neat hack.

I think it’s starting to make sense now. Here’s what I think is going on (code: Daniel’s code for ‘Learning To Reweight’ algorithm):

```python
def train_lre():
    net, opt = build_model()  # uses MetaModule to create the model instead of nn.Module

    meta_losses_clean = []
    net_losses = []
    plot_step = 100

    smoothing_alpha = 0.9

    meta_l = 0
    net_l = 0
    accuracy_log = []
    for i in tqdm(range(hyperparameters['num_iterations'])):
        net.train()
        # Line 2: get a batch of data
        image, labels = next(iter(data_loader))
        # since the validation data is small I just fixed it instead of building an iterator
        # initialize a dummy network for the meta-learning of the weights
        meta_net = LeNet(n_out=1)
        meta_net.load_state_dict(net.state_dict())

        if torch.cuda.is_available():
            meta_net.cuda()

        image = to_var(image, requires_grad=False)
        labels = to_var(labels, requires_grad=False)

        # Lines 4 - 5: initial forward pass to compute the initial weighted loss
        y_f_hat = meta_net(image)
        cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduce=False)
        eps = to_var(torch.zeros(cost.size()))
        l_f_meta = torch.sum(cost * eps)

        meta_net.zero_grad()

        # Line 6: perform a (differentiable) parameter update on the dummy network
        grads = torch.autograd.grad(l_f_meta, (meta_net.params()), create_graph=True)
        meta_net.update_params(hyperparameters['lr'], source_params=grads)

        # Lines 8 - 10: 2nd forward pass and gradients with respect to epsilon
        y_g_hat = meta_net(val_data)

        l_g_meta = F.binary_cross_entropy_with_logits(y_g_hat, val_labels)

        grad_eps = torch.autograd.grad(l_g_meta, eps, only_inputs=True)[0]

        # Line 11: computing and normalizing the weights
        w_tilde = torch.clamp(-grad_eps, min=0)
        norm_c = torch.sum(w_tilde)

        if norm_c != 0:
            w = w_tilde / norm_c
        else:
            w = w_tilde

        # Lines 12 - 14: computing the loss with the computed weights
        # and then performing a parameter update on the real network
        y_f_hat = net(image)
        cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduce=False)
        l_f = torch.sum(cost * w)

        net.zero_grad()
        l_f.backward()
        opt.step()

        # (loss smoothing and accuracy logging elided; acc_log is built from accuracy_log there)

    return np.mean(acc_log[-6:-1, 1])
```

The trick here is to store the trainable tensors (`weight` and `bias`) as buffers (created via `register_buffer` in `MetaLinear`, etc.) instead of as `nn.Parameter`s, since `nn.Parameter`s don’t record any history, and to manipulate them via the `named_leaves`, `named_params`, `update_params`, and `set_param` functions to do the meta-learning. The base network’s parameters are updated via `opt.step()`, whereas all the intermediate parameter updates required for meta-learning are handled via the buffer tensors (viz. `weight` and `bias` created via `register_buffer`).
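To convince myself, I put together a tiny sketch of this buffer trick (my own hypothetical `TinyMetaLinear`, not Daniel’s actual layer):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMetaLinear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        ignore = nn.Linear(in_dim, out_dim)
        # store the trainable tensors as buffers, NOT as nn.Parameters,
        # so later assignments can hold Tensors that carry autograd history
        self.register_buffer('weight', ignore.weight.data.clone().requires_grad_())
        self.register_buffer('bias', ignore.bias.data.clone().requires_grad_())

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

torch.manual_seed(0)
layer = TinyMetaLinear(3, 1)
x, y = torch.randn(4, 3), torch.randn(4, 1)
w0, b0 = layer.weight, layer.bias  # keep handles to the original leaf tensors

# inner-loop step, differentiable thanks to create_graph=True
loss = F.mse_loss(layer(x), y)
grads = torch.autograd.grad(loss, (layer.weight, layer.bias), create_graph=True)

# "set_param": plain assignment is allowed because 'weight' is a buffer,
# and the assigned Tensor keeps its history (its grad_fn is set)
layer.weight = layer.weight - 0.1 * grads[0]
layer.bias = layer.bias - 0.1 * grads[1]
print(layer.weight.grad_fn is not None)  # True: the update is in the graph

# the outer loss backpropagates through the inner update to the originals
meta_loss = F.mse_loss(layer(x), y)
meta_loss.backward()
print(w0.grad is not None)  # True
```

Doing `layer.weight = tmp` on a regular `nn.Linear` would raise a `TypeError`; with a buffer it is just an ordinary attribute swap.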

@albanD Sorry for bothering you so much but I think that does explain it, don’t you think?

Hi,

Yes I think it does.
Note that storing the intermediary results as buffers vs. just regular attributes doesn’t change much.
It only matters when you get the state dict or move the module to a different device, but hopefully you are not doing that in the middle of the forward pass.

I didn’t get you, @albanD. Can you elaborate a bit?

For any Tensor in an `nn.Module`, you can store it on `self` by doing `self.foo = your_tensor` or by doing `self.register_buffer("foo", your_tensor)`.
I think both will have the behavior that you want: you can access them via `self.foo` and they can have history.

I think the first one might be simpler to read, as it is basic Python semantics.
I can’t think of any reason why you would need it to actually be a buffer (maybe I’m missing something, though).
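To illustrate the difference (quick sketch): both stores are accessible as `self.foo`, but only the buffer participates in `state_dict()` and in `.to()`/`.cuda()` conversions.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('buf', torch.zeros(2))  # tracked by the module
        self.attr = torch.zeros(2)                   # plain Python attribute

demo = Demo()
print('buf' in demo.state_dict())    # True
print('attr' in demo.state_dict())   # False

demo = demo.to(torch.float64)
print(demo.buf.dtype)   # torch.float64 -- buffer was converted
print(demo.attr.dtype)  # torch.float32 -- plain attribute untouched
```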