Any elegant way to implement the MAML algorithm?

Dear all:
The MAML algorithm (https://arxiv.org/abs/1703.03400) computes gradients in both an inner loop and an outer loop.
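For reference, here are the two updates from the paper, written for a single inner gradient step. α and β are the inner and outer step sizes, and T_i is a task sampled from p(T). The outer update differentiates through θ'_i, so the network has to be evaluated with weights other than its own:

$$
\theta'_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)
\qquad\text{(inner loop, per task)}
$$

$$
\theta \leftarrow \theta - \beta \nabla_\theta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}\big(f_{\theta'_i}\big)
\qquad\text{(outer loop)}
$$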
Current implementations write the network as:

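import torch
from torch import nn
from torch import optim
from torch.nn import functional as F

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        # every weight tensor is declared by hand
        self.vars = nn.ParameterList([
            nn.Parameter(torch.randn(3, 3) * 0.1),  # linear weight
            nn.Parameter(torch.zeros(3)),           # linear bias
        ])

    def forward(self, x, vars=None):
        # use the model's own weights unless adapted ("fast") weights are passed in
        if vars is None:
            vars = self.vars
        w, b = vars[0], vars[1]
        # the network is rebuilt by hand from functional ops
        x = F.linear(x, w, b)
        return x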
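To show why `forward` takes a `vars` argument, here is a minimal sketch of one second-order MAML inner step with this hand-written model. The helper `inner_step`, the MSE loss, the random toy data and `inner_lr=0.01` are all hypothetical choices for illustration. `create_graph=True` keeps the inner update differentiable, so the query loss can backpropagate into the original parameters.

```python
import torch
from torch import nn
from torch.nn import functional as F


class Model(nn.Module):
    """Hand-written functional model in the style of the post."""

    def __init__(self):
        super().__init__()
        self.vars = nn.ParameterList([
            nn.Parameter(torch.randn(3, 3) * 0.1),  # weight
            nn.Parameter(torch.zeros(3)),           # bias
        ])

    def forward(self, x, vars=None):
        if vars is None:
            vars = self.vars
        w, b = vars[0], vars[1]
        return F.linear(x, w, b)


def inner_step(model, x, y, inner_lr=0.01):
    """Hypothetical helper: one inner-loop step, returns adapted weights as a list."""
    loss = F.mse_loss(model(x), y)
    # create_graph=True lets the outer loss differentiate through this update
    grads = torch.autograd.grad(loss, list(model.vars), create_graph=True)
    return [p - inner_lr * g for p, g in zip(model.vars, grads)]


torch.manual_seed(0)
model = Model()
x_support, y_support = torch.randn(5, 3), torch.randn(5, 3)
x_query, y_query = torch.randn(5, 3), torch.randn(5, 3)

fast_weights = inner_step(model, x_support, y_support)
# the same forward() is reused, but evaluated with the adapted weights
query_loss = F.mse_loss(model(x_query, fast_weights), y_query)
query_loss.backward()  # gradients reach model.vars through the inner step
```

Every new layer must be added both to `self.vars` and to `forward`. That duplication is the cost the question is about.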
  

Notably, we need to declare every weight tensor by hand and then build the network from F.linear, F.conv2d, F.relu, F.max_pool2d. This is costly every time, and any modification of the network requires rewriting the hand-written code.
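One way to avoid the hand-written forward is sketched below. It assumes PyTorch 2.0 or later, where `torch.func.functional_call` runs an ordinary `nn.Module` with a dictionary of substitute parameters keyed by the module's own parameter names. Older releases exposed a similar function under `torch.nn.utils.stateless`. The layer sizes, the cross-entropy loss and the 0.01 step size are illustrative assumptions.

```python
import torch
from torch import nn
from torch.func import functional_call  # PyTorch >= 2.0
from torch.nn import functional as F

# An ordinary module: no hand-written F.conv2d / F.linear calls needed.
net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # 28x28 -> 28x28
    nn.ReLU(),
    nn.MaxPool2d(2),                            # 28x28 -> 14x14
    nn.Flatten(),
    nn.Linear(8 * 14 * 14, 5),
)

x = torch.randn(4, 1, 28, 28)
y = torch.randint(0, 5, (4,))

# "slow" weights, keyed by names such as "0.weight", "4.bias"
params = dict(net.named_parameters())

# one inner-loop step, kept differentiable with create_graph=True
loss = F.cross_entropy(functional_call(net, params, (x,)), y)
grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
fast_params = {name: p - 0.01 * g for (name, p), g in zip(params.items(), grads)}

# run the same network with the adapted weights
logits = functional_call(net, fast_params, (x,))
```

Changing the architecture now only means editing the `nn.Sequential`. The parameter dictionary follows automatically from `named_parameters()`.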

Does anyone have an elegant way to build a MAML-style network?

There is a discussion about MAML pointing to some implementations using the functional API.
Would they work for you?

@ptrblck, yes. By troublesome coding I mean writing the network with the functional API line by line.

But I am looking for a way to achieve this without writing that code.
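Putting the pieces together, below is a minimal sketch of a full MAML meta-training loop built on `torch.func.functional_call` (PyTorch 2.0+ assumed). The helpers `adapt` and `sample_task`, the toy linear-regression tasks, the small MLP and all hyperparameters are hypothetical choices for illustration. They are not taken from the paper or from any particular implementation.

```python
import torch
from torch import nn
from torch.func import functional_call
from torch.nn import functional as F


def adapt(net, params, x, y, inner_lr, steps):
    """Hypothetical inner loop: return task-specific fast weights (a dict)."""
    for _ in range(steps):
        loss = F.mse_loss(functional_call(net, params, (x,)), y)
        # create_graph=True gives second-order MAML; drop it for first-order
        grads = torch.autograd.grad(loss, list(params.values()), create_graph=True)
        params = {n: p - inner_lr * g for (n, p), g in zip(params.items(), grads)}
    return params


def sample_task(n=10):
    """Hypothetical toy task: y = a * x + b with random a, b."""
    a, b = torch.randn(2)
    xs, xq = torch.randn(n, 1), torch.randn(n, 1)
    return xs, a * xs + b, xq, a * xq + b


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(1, 32), nn.ReLU(), nn.Linear(32, 1))
meta_opt = torch.optim.Adam(net.parameters(), lr=1e-3)
tasks_per_batch = 4

for _ in range(100):
    meta_opt.zero_grad()
    meta_loss = 0.0
    for _ in range(tasks_per_batch):
        xs, ys, xq, yq = sample_task()
        fast = adapt(net, dict(net.named_parameters()), xs, ys, inner_lr=0.01, steps=1)
        meta_loss = meta_loss + F.mse_loss(functional_call(net, fast, (xq,)), yq)
    # outer loop: backpropagates through adapt() into the slow weights
    (meta_loss / tasks_per_batch).backward()
    meta_opt.step()
```

The model itself stays a plain `nn.Module`. Only the inner loop needs to know about fast weights, so it can be reused unchanged for other architectures.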