Out-of-memory error in an RNN language model with meta-learning

Hi there.
I am new to the community.
I would like to ask about meta-learning and out-of-memory errors.
I implemented code similar to the one explained in https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a, but now I am running out of memory.
I tried it on both CPU and GPU, and both fail with the same problem.

Some details about the code:
Variable declaration:

    model_foward = ClassifierWithState(RNNLM(args.n_vocab, args.layer, args.unit))
    model_backward = ClassifierWithState(RNNLM(args.n_vocab, args.layer, args.unit))
    optimizer = MetaLearner(None)
    meta_optimizer = torch.optim.SGD(optimizer.parameters(), lr=1.0)
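The two `ClassifierWithState(RNNLM(...))` calls above create two separate modules. Each has its own freshly initialized parameters, and they share nothing. Below is a minimal sketch of how to check this. It makes two assumptions: a hypothetical `make_model()` returns a tiny `nn.Linear` in place of the real language model, and `load_state_dict` (not something shown in the post) gives the backward copy the same starting values as the forward one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def make_model():
    # Hypothetical stand-in for ClassifierWithState(RNNLM(...)).
    return nn.Linear(4, 3)


model_forward = make_model()
model_backward = make_model()

# Two separately constructed modules own separate parameter tensors.
shared_before = any(
    pf.data_ptr() == pb.data_ptr()
    for pf, pb in zip(model_forward.parameters(), model_backward.parameters())
)

# Copy the forward values into the backward copy (storage stays separate).
model_backward.load_state_dict(model_forward.state_dict())
same_values = all(
    torch.equal(pf, pb)
    for pf, pb in zip(model_forward.parameters(), model_backward.parameters())
)
shared_after = any(
    pf.data_ptr() == pb.data_ptr()
    for pf, pb in zip(model_forward.parameters(), model_backward.parameters())
)
print(shared_before, same_values, shared_after)  # False True False
```

Copying the state dict makes the values equal but keeps the storage separate. Each copy can still be updated on its own.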

MetaLearner Class:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MetaLearner(nn.Module):
        """ Bare Meta-learner class
            Should be added: initialization, hidden states, more control over everything
        """
        def __init__(self, model):
            super(MetaLearner, self).__init__()
            # torch.Tensor(1, 2) is uninitialized memory; use small explicit values instead
            self.weights = torch.nn.Parameter(torch.empty(1, 2).uniform_(-0.01, 0.01))

        def forward(self, forward_model, backward_model):
            """ Forward optimizer with a simple linear neural net
                forward_model: PyTorch module with parameters gradient populated
                backward_model: PyTorch module identical to forward_model (but without gradients)
                  updated at the Parameter level to keep track of the computation graph for meta-backward pass
            """
            # get_params(): helper defined in the linked post (not shown here)
            f_model_iter = get_params(forward_model)
            b_model_iter = get_params(backward_model)
            for f_param_tuple, b_param_tuple in zip(f_model_iter, b_model_iter):  # loop over parameters
                # Prepare the inputs; detach them to avoid computing 2nd derivatives
                (module_f, name_f, param_f) = f_param_tuple
                (module_b, name_b, param_b) = b_param_tuple
                inputs = torch.stack([param_f.grad.detach(), param_f.detach()], dim=-1)
                # Optimization step: compute new model parameters with a simple linear function.
                # squeeze(-1) drops only the output dim, so size-1 parameter dims are kept.
                dW = F.linear(inputs, self.weights).squeeze(-1)
                param_b = param_b + dW
                # Update backward_model (meta-gradients can flow) and forward_model (no need for meta-gradients)
                module_b._parameters[name_b] = param_b
                param_f.data = param_b.data
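Here is a self-contained sketch of what calls to `MetaLearner.forward` do to the backward copy, using the same update rule on a toy `nn.Linear` model. It rests on three assumptions:

* a simplified, hypothetical `iter_params` helper stands in for `get_params` and ignores shared weights;
* an MSE loss stands in for the language-model loss;
* the meta-weights get a small uniform initialization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def iter_params(module):
    # Hypothetical, simplified stand-in for get_params: yields
    # (owning_module, name, tensor) and ignores shared weights.
    for sub in module.modules():
        for name, p in list(sub._parameters.items()):
            if p is not None:
                yield sub, name, p


class MetaLearner(nn.Module):
    def __init__(self):
        super().__init__()
        # Small explicit init instead of the uninitialized torch.Tensor(1, 2).
        self.weights = nn.Parameter(torch.empty(1, 2).uniform_(-0.01, 0.01))

    def forward(self, forward_model, backward_model):
        pairs = zip(iter_params(forward_model), iter_params(backward_model))
        for (_, _, p_f), (m_b, n_b, p_b) in pairs:
            # Detached inputs: no second derivatives through the gradient.
            inputs = torch.stack([p_f.grad.detach(), p_f.detach()], dim=-1)
            dW = F.linear(inputs, self.weights).squeeze(-1)
            new_p = p_b + dW                  # stays linked to self.weights
            m_b._parameters[n_b] = new_p      # backward copy keeps the graph
            p_f.data = new_p.detach()         # forward copy gets values only


model_f = nn.Linear(3, 2)
model_b = nn.Linear(3, 2)
model_b.load_state_dict(model_f.state_dict())
meta = MetaLearner()

x, y = torch.randn(8, 3), torch.randn(8, 2)
for step in range(3):
    model_f.zero_grad()
    F.mse_loss(model_f(x), y).backward()
    meta(model_f, model_b)

# model_b.weight is now a non-leaf tensor whose graph spans all three
# updates and ends at meta.weights, so the meta-loss reaches them.
meta_loss = F.mse_loss(model_b(x), y)
meta_loss.backward()
print(model_b.weight.grad_fn is not None, meta.weights.grad is not None)  # True True
```

Each update adds an `F.linear` node that saves the stacked `(grad, param)` input for the meta-backward pass. The graph attached to the backward copy, and the memory it holds, therefore grows with every update until the meta-loss is backpropagated.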

Training loop:

        # Progress the dataset iterator for sentences at each iteration.
        batch = next(train_iter)
        losses = []
        for j in six.moves.range(len(batch)):
            # print('{} / {} \r'.format(j, len(batch)))
            x, t = convert_examples(batch[j], self.device)
            loss = 0
            count = 0
            state = None
            batch_size, sequence_length = x.shape
            # Sequence Forward

            for i in six.moves.range(sequence_length):
                # Compute the loss at this time step and accumulate it
                state, loss_batch = self.model_foward(state, x[:, i], t[:, i])
                non_zeros = torch.sum(x[:, i] != 0, dtype=torch.float)
                loss += loss_batch * non_zeros
                count += int(non_zeros)
            loss.backward(retain_graph=True)
            self._optimizer(self.model_foward, self.model_backward)

        meta_loss = sum(losses)
        # logging.info('meta loss: {}'.format(float(meta_loss.detach())))
        reporter.report({'loss': float(meta_loss.detach())}, meta_optimizer.target)
        reporter.report({'count': count}, meta_optimizer.target)
        if self.gradclip is not None:
            nn.utils.clip_grad_norm_(self.model_foward.parameters(), self.gradclip)
            nn.utils.clip_grad_norm_(self.model_backward.parameters(), self.gradclip) 

The RNNLM is an LSTM network with one layer.
Neither the network nor the meta-network has many parameters to update, and I only update the meta-learner after a certain number of input samples. However, the program still crashes because it runs out of memory.
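Updating the meta-learner only every few samples is usually paired with cutting the unrolled graph after each meta-update. That way the next window starts from fresh leaf tensors instead of chaining onto the previous one. Below is a minimal sketch of this truncation on a toy `nn.Linear` model. `iter_params`, `detach_backward_model` and `meta_every` are hypothetical names, and the MSE loss stands in for the language-model loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def iter_params(module):
    # Hypothetical, simplified stand-in for get_params (no shared weights).
    for sub in module.modules():
        for name, p in list(sub._parameters.items()):
            if p is not None:
                yield sub, name, p


class MetaLearner(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.empty(1, 2).uniform_(-0.01, 0.01))

    def forward(self, forward_model, backward_model):
        pairs = zip(iter_params(forward_model), iter_params(backward_model))
        for (_, _, p_f), (m_b, n_b, p_b) in pairs:
            inputs = torch.stack([p_f.grad.detach(), p_f.detach()], dim=-1)
            new_p = p_b + F.linear(inputs, self.weights).squeeze(-1)
            m_b._parameters[n_b] = new_p      # graph kept for the meta-loss
            p_f.data = new_p.detach()         # forward copy: values only


def detach_backward_model(backward_model):
    # Hypothetical helper: re-wrap every parameter as a fresh leaf so the
    # next window does not chain onto the previous window's graph.
    for m_b, n_b, p_b in list(iter_params(backward_model)):
        m_b._parameters[n_b] = nn.Parameter(p_b.detach().clone(),
                                            requires_grad=False)


model_f = nn.Linear(3, 2)
model_b = nn.Linear(3, 2)
model_b.load_state_dict(model_f.state_dict())
meta = MetaLearner()
meta_optimizer = torch.optim.SGD(meta.parameters(), lr=1e-3)

x, y = torch.randn(8, 3), torch.randn(8, 2)
meta_every = 4      # hypothetical: inner steps per meta-update
meta_losses = []

for step in range(1, 13):
    model_f.zero_grad()
    F.mse_loss(model_f(x), y).backward()   # backpropagated once: no retain_graph
    meta(model_f, model_b)

    if step % meta_every == 0:
        meta_optimizer.zero_grad()
        meta_loss = F.mse_loss(model_b(x), y)  # loss of the unrolled copy
        meta_loss.backward()                   # frees this window's graph
        meta_optimizer.step()
        meta_losses.append(meta_loss.item())   # store a float, not a tensor
        detach_backward_model(model_b)
```

With this truncation, the graph held by the backward copy never spans more than `meta_every` updates. Appending `meta_loss.item()` rather than the loss tensor also keeps the list from holding references to autograd graphs.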

Am I missing something?
I thought I was using shared weights, but do I need to declare them again?