Dear all:
The MAML algorithm (https://arxiv.org/abs/1703.03400) involves both an outer-loop and an inner-loop gradient computation.
My current implementation writes the network like this:
import torch
from torch import nn
from torch.nn import functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Parameters are kept in a plain list so the inner loop can
        # pass in "fast weights" in place of the module's own parameters.
        self.vars = nn.ParameterList([
            nn.Parameter(torch.Tensor(3, 3)),
            nn.Parameter(torch.Tensor(3)),
        ])

    def forward(self, x, vars=None):
        if vars is None:
            vars = self.vars
        w, b = vars
        x = F.linear(x, w, b)
        return x
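For context, here is a minimal sketch of how the two gradient loops interact with such a hand-written model. This is my understanding of the usual MAML pattern, not code from the paper: the inner loop builds "fast weights" with `create_graph=True`, and the outer loss is evaluated with those fast weights so it can backpropagate to the original parameters. The learning rate and the toy data are placeholders.

```python
import torch
from torch import nn
from torch.nn import functional as F

# Same hand-written functional model as in the snippet above.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.vars = nn.ParameterList([
            nn.Parameter(torch.randn(3, 3) * 0.1),
            nn.Parameter(torch.zeros(3)),
        ])

    def forward(self, x, vars=None):
        if vars is None:
            vars = self.vars
        w, b = vars
        return F.linear(x, w, b)

model = Model()
x, y = torch.randn(4, 3), torch.randn(4, 3)  # toy task data
inner_lr = 0.01

# Inner loop: one adaptation step. create_graph=True keeps the update
# differentiable so the outer loop can differentiate through it.
loss = F.mse_loss(model(x), y)
grads = torch.autograd.grad(loss, list(model.vars), create_graph=True)
fast_weights = [p - inner_lr * g for p, g in zip(model.vars, grads)]

# Outer loop: evaluate with the fast weights; this loss backpropagates
# into the original parameters (second-order MAML).
outer_loss = F.mse_loss(model(x, fast_weights), y)
outer_loss.backward()
```

This is exactly why `forward` must accept an external `vars` list: the fast weights are plain tensors, not the module's registered parameters.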
Notably, we have to declare every parameter tensor by hand and then build the network with the functional API (F.linear, F.conv2d, F.relu, F.max_pool2d). This is very time-consuming, and any modification to the network requires rewriting the hand-written code.
Does anyone have an elegant way to achieve a MAML-style network?