Hello there
I am experimenting with meta-learning and copied the program from https://medium.com/huggingface/from-zero-to-research-an-introduction-to-meta-learning-8e16e677f78a.
The problem is that after a certain number of iterations I get an out-of-memory (OOM) error.
The model is not particularly large and I am testing with the MNIST dataset, so I would expect memory not to be a problem.
In the original code, loss.backward() is called without retaining the graph, but if I remove retain_graph=True I get an error about the graph (and its gradients) having already been freed when the meta-backward pass runs. I assumed the memory would be released after losses = [], but it seems that it is not.
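To illustrate the pattern I mean, here is a minimal self-contained sketch (using a tiny hypothetical linear model, not the actual code pasted below): every loss tensor kept in the losses list still references its computation graph, and retain_graph=True additionally keeps that graph's buffers alive after backward(), so nothing can be freed until the list itself is dropped.

import torch
import torch.nn as nn

toy_model = nn.Linear(10, 1)             # hypothetical stand-in for the real Net below
losses = []
for step in range(100):
    x = torch.randn(32, 10)
    y = torch.randn(32, 1)
    loss = ((toy_model(x) - y) ** 2).mean()
    losses.append(loss)                  # the stored loss keeps its whole graph referenced
    loss.backward(retain_graph=True)     # graph buffers are retained instead of freed here

meta_loss = sum(losses)                  # builds a sum node over all retained graphs
meta_loss.backward()                     # meta-backward pass through every batch at once
losses = []                              # this is where I expected the memory to be released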
Any advice on how to solve this problem?
Regards.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from matplotlib import pyplot as plt
from torchvision import datasets, transforms
from sys import stdout
class MetaLearner(nn.Module):
    """ Bare Meta-learner class
    Should be added: initialization, hidden states, more control over everything
    """
    def __init__(self, model):
        super(MetaLearner, self).__init__()
        self.weights = nn.Parameter(torch.Tensor(1, 2))

    def forward(self, forward_model, backward_model):
        """ Forward optimizer with a simple linear neural net
        Inputs:
            forward_model: PyTorch module with parameters gradient populated
            backward_model: PyTorch module identical to forward_model (but without gradients),
                updated at the Parameter level to keep track of the computation graph for the meta-backward pass
        """
        f_model_iter = get_params(forward_model)
        b_model_iter = get_params(backward_model)
        for f_param_tuple, b_param_tuple in zip(f_model_iter, b_model_iter):  # loop over parameters
            # Prepare the inputs; we detach them to avoid computing 2nd derivatives (re-pack in a new Variable)
            (module_f, name_f, param_f) = f_param_tuple
            (module_b, name_b, param_b) = b_param_tuple
            inputs = torch.autograd.Variable(torch.stack([param_f.grad.data, param_f.data], dim=-1))
            dims = len(inputs.data.shape)
            # Optimization step: compute new model parameters, here we apply a simple linear function
            dW = F.linear(inputs, self.weights).squeeze(dims - 1)
            param_b = param_b + dW
            # Update backward_model (meta-gradients can flow) and forward_model (no need for meta-gradients).
            module_b._parameters[name_b] = param_b
            param_f.data = param_b.data
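# Written out per parameter, the learned update above amounts to (my reading of the code):
#   dW = weights[0, 0] * param.grad + weights[0, 1] * param
# i.e. a learnable combination of a gradient-descent step and a weight-decay term.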
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
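# For reference: Net expects MNIST-sized batches of shape (N, 1, 28, 28) and returns
# log-probabilities over 10 classes; after the two conv/pool stages the feature map is
# 20 x 4 x 4 = 320 values, matching the input size of fc1.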
def get_params(module, memo=None, pointers=None):
    """ Returns an iterator over PyTorch module parameters that allows updating the parameters themselves
    (and not only their data).
    ! Side effect: updates shared parameters to point to the first yielded instance
    (i.e. you can update shared parameters and keep them shared)
    Yields:
        (Module, string, Parameter): Tuple containing the parameter's module, name and pointer
    """
    if memo is None:
        memo = set()
        pointers = {}
    for name, p in module._parameters.items():
        if p not in memo:
            memo.add(p)
            pointers[p] = (module, name)
            yield module, name, p
        elif p is not None:
            prev_module, prev_name = pointers[p]
            module._parameters[name] = prev_module._parameters[prev_name]  # update shared parameter pointer
    for child_module in module.children():
        for m, n, p in get_params(child_module, memo, pointers):
            yield m, n, p
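# Example of how this iterator is meant to be used (a hypothetical plain SGD step,
# rebinding each Parameter in its module so that the new value is picked up):
#   for module, name, param in get_params(net):
#       module._parameters[name] = param - 0.1 * param.grad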
def train(forward_model, backward_model, optimizer, meta_optimizer, train_data, meta_epochs, device):
    """ Train a meta-learner
    Inputs:
        forward_model, backward_model: Two identical PyTorch modules (can have shared Tensors)
        optimizer: a neural net to be used as optimizer (an instance of the MetaLearner class)
        meta_optimizer: an optimizer for the optimizer neural net, e.g. ADAM
        train_data: an iterator over an epoch of training data
        meta_epochs: number of meta-training steps
    To be added: initialization, early stopping, checkpointing, more control over everything
    """
    forward_model.train()
    backward_model.train()
    optimizer.train()
    for meta_epoch in range(meta_epochs):  # Meta-training loop (train the optimizer)
        optimizer.zero_grad()
        losses = []
        print('#### Epoch: {}'.format(meta_epoch))
        for batch_idx, (inputs, labels) in enumerate(train_data):  # Meta-forward pass (train the model)
            stdout.write('Batch: {}\r'.format(batch_idx))
            stdout.flush()
            forward_model.zero_grad()  # Forward pass
            inputs, labels = inputs.to(device), labels.to(device)
            output = forward_model(inputs)
            loss = F.nll_loss(output, labels)
            losses.append(loss)
            loss.backward(retain_graph=True)  # Backward pass to add gradients to the forward_model
            optimizer(forward_model,          # Optimizer step (update the models)
                      backward_model)
            if batch_idx % 50:
                meta_loss = sum(losses)   # Compute a simple meta-loss
                meta_loss.backward()      # Meta-backward pass
                meta_optimizer.step()     # Meta-optimizer step
                print(float(meta_loss.data))
                optimizer.zero_grad()
                losses = []
            if batch_idx > 1000:
                exit()
def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    model_fwd = Net().to(device)
    model_bwd = Net().to(device)
    optimizer = MetaLearner(model_fwd).to(device)
    meta_opt = optim.SGD(optimizer.parameters(), lr=0.1, momentum=0.5)
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('./data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=20, shuffle=True, **kwargs)
    train(model_fwd, model_bwd, optimizer, meta_opt, train_loader, 10, device)


if __name__ == '__main__':
    main()