With backward(retain_graph=True) I can keep the current graph for future backprops. I understand that the last backprop should then use retain_graph=False in order to free the graph. However, at the point of that backward pass I do not yet have this information. Is there any way to manually free the graph (hopefully, other than detaching every Variable or running a backward without an update)?
Hi,
Whenever the output Variable goes out of scope in Python, the whole graph is deleted. By default, some intermediary buffers are freed even before that to reduce peak memory usage (this is what is disabled when using retain_graph=True). But the graph and all intermediary buffers are only kept alive as long as they are accessible from Python (usually from the output Variable), so running the last backward with retain_graph=True will only keep the intermediary buffers alive until they are freed with the rest of the graph when the Python Variable goes out of scope. So you don't need to manually free the graph. If the output Variable does not go out of scope in Python, you can call del your_out_variable so that it is deleted (and the graph associated with it will be as well).
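For instance, a minimal sketch of that behavior (the tensor names here are made up for illustration):

import torch

x = torch.randn(3, requires_grad=True)
out = (x * 2).sum()

out.backward(retain_graph=True)   # graph kept alive for another backward
out.backward()                    # fine: the intermediary buffers are still there

out2 = (x * 3).sum()
out2.backward(retain_graph=True)  # buffers kept, but only while out2 is referenced
del out2                          # last reference gone: graph and buffers are freed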
If I try that and it doesn't work, is there something else I can do?
If that did not free the graph, it means that you have other Python objects that still reference it.
Like the last node or a node inside the computation graph? I don't think I reference the final loss again…
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required
import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
import copy
import itertools
import sys
from collections import OrderedDict
from pdb import set_trace as st

# mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size, -1)
        return out  # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader

def load_new_params(optimizer, params):
    optimizer.param_groups = []
    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]
    for param_group in param_groups:
        optimizer.add_param_group(param_group)

def reload_param_groups(opt, params):
    if isinstance(params, torch.Tensor):
        raise TypeError("params argument given to the optimizer should be "
                        "an iterable of Tensors or dicts, but got " +
                        torch.typename(params))
    # replace params
    params = list(params)
    if isinstance(params[0], dict):
        raise ValueError('The hacked higher version does not support proper pytorch grouped params yet.')
    opt.param_groups[0]['params'] = params
    # opt.param_groups = []
    # param_groups = list(params)
    # if len(param_groups) == 0:
    #     raise ValueError("optimizer got an empty parameter list")
    # if not isinstance(param_groups[0], dict):
    #     param_groups = [{'params': param_groups}]
    # for param_group in param_groups:
    #     opt.add_param_group(param_group)

class MySGD(Optimizer):
    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):
    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr
        # update model
        # new_params = self.param_groups[0]['params']
        # new_params = self._track_higher_grads_for_new_params(new_params, self._track_higher_grads)
        # self._fmodel.update_params(new_params)

higher.register_optim(MySGD, TrainableSGD)

def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    # get trainable opt params
    hidden = torch.randn(size=(1,1), requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1, 1, bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))
    lr = 0.01
    meta_params = []
    meta_params.append({'params': hidden, 'lr': lr})
    meta_params.append({'params': eta.parameters(), 'lr': lr})
    # get meta optimizer
    # meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    #
    trainable_opt_params = {'eta':eta, 'hidden':hidden}
    trainable_opt_state = {'prev_lr':hidden}
    # inner_opt = MySGD(eta.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
    # diffopt = higher.optim.get_diff_optim(
    #     inner_opt,
    #     eta.parameters(),  # for this hack it can be anything
    #     fmodel=None,  # None
    #     device=device,
    #     override=None,  # None default
    #     track_higher_grads=True  # True default
    # )
    # do meta-training / outer loop: argmin L^val(theta)
    nb_outer_steps = 2  # note: in this case it's the same as the number of meta-train steps (but it could differ depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # sample child_model
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3, out_channels=2, kernel_size=5, bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2, out_features=10, bias=False))
        ]))
        # do inner-training: ~ argmin L^train(psi)
        nb_inner_steps = 3
        print('==== Inner Loop ====')
        fmodel = higher.patch.monkeypatch(
            child_model,
            device,
            copy_initial_weights=True  # True default
        )
        inner_opt = MySGD(child_model.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        diffopt = higher.optim.get_diff_optim(
            inner_opt,
            child_model.parameters(),  # for this hack it can be anything
            fmodel=fmodel,  # None
            device=device,
            override=None,  # None default
            track_higher_grads=True  # True default
        )
        for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
            if inner_i >= nb_inner_steps:
                break
            print(f'-> outer_i = {outer_i}')
            print(f'-> inner_i = {inner_i}')
            print(f'hidden^<{inner_i}> = {hidden}')
            logits = fmodel(inner_inputs)
            inner_loss = criterion(logits, inner_targets)
            print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            # child_model_params = [{'params':child_model.parameters()}]
            child_model_params = child_model.parameters()
            reload_param_groups(diffopt, child_model_params)
            diffopt._fmodel = fmodel
            diffopt.step(inner_loss)
            print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            print(f'hidden^<{inner_i}> = {hidden}')
        # compute the meta-loss L^val( theta^{T} - eta * Grad L^train(theta^{T}) )
        outer_outputs = fmodel(outer_inputs)
        meta_loss = criterion(outer_outputs, outer_targets)  # L^val
        # grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters())  # dmeta_loss/dw0
        print('\n---- Outer loop print statements ----')
        print(f'----> outer_i = {outer_i}')
        print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
        # print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
        meta_loss.backward()
        print(f'hidden.grad = {hidden.grad}')
        assert hidden.grad is not None
        print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
        print(f'> hidden^<{outer_i-1}> = {hidden}')  # before update
        print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        meta_opt.step()  # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta * Grad L^train(theta^{T}) )
        print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}')  # after update
        print(f'>> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        del meta_loss
        meta_opt.zero_grad()
        print()

if __name__ == "__main__":
    main()
    print('---> Done\a')
Well, you definitely reference outer_outputs and outer_targets, so only the criterion() part of the graph can potentially be freed.
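As an illustration (a hedged sketch with hypothetical names, not your actual code), deleting the loss alone frees only the top of the graph when an intermediate output is still referenced:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()

x = torch.randn(8, 4)
targets = torch.randint(0, 2, (8,))

outputs = model(x)                  # `outputs` keeps the model part of the graph reachable
loss = criterion(outputs, targets)  # `loss` adds the criterion part on top

loss.backward(retain_graph=True)
del loss                            # only the criterion() part can be freed here,
print(outputs.grad_fn)              # because `outputs` still keeps the rest alive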
This works:
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required
import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
import copy
import itertools
import sys
from collections import OrderedDict
from pdb import set_trace as st

# mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size, -1)
        return out  # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):
    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):
    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr

higher.register_optim(MySGD, TrainableSGD)

def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f'device = {device}')
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    # get inner opt
    hidden = torch.randn(size=(1,1), device=device, requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1, 1, bias=False)),
        ('sigmoid', nn.Sigmoid())
    ])).to(device)
    trainable_opt_params = {'eta':eta, 'hidden':hidden}
    trainable_opt_state = {'prev_lr':hidden}
    # get outer opt
    lr = 0.05
    meta_params = []
    meta_params.append({'params': hidden, 'lr': lr})
    meta_params.append({'params': eta.parameters(), 'lr': lr})
    # meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    # do meta-training/outer training, minimizing the outer loop: min_{theta} sum_t L^val( theta^{T} - eta * Grad L^train(theta^{T}) )
    nb_outer_steps = 10  # note: in this case it's the same as the number of meta-train steps (but it could differ depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        outer_inputs, outer_targets = outer_inputs.to(device), outer_targets.to(device)
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        #
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3, out_channels=2, kernel_size=5, bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2, out_features=10, bias=False))
        ])).to(device)
        # do inner-training: ~ argmin L^train(theta)
        nb_inner_steps = 2
        trainable_opt_params = {'eta':eta, 'hidden':hidden}
        trainable_opt_state = {'prev_lr':hidden}
        print(trainable_opt_state)
        child_model_params = [{'params':child_model.parameters()}]
        inner_opt = MySGD(child_model_params, trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        print('==== Inner Loop ====')
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=True) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                inner_inputs, inner_targets = inner_inputs.to(device), inner_targets.to(device)
                if inner_i >= nb_inner_steps:
                    break
                print(f'-> outer_i = {outer_i}')
                print(f'-> inner_i = {inner_i}')
                print(f'hidden^<{inner_i}> = {hidden}')
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                diffopt.step(inner_loss)
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                print(f'hidden^<{inner_i}> = {hidden}')
            # compute the meta-loss L^val( theta^{T} - eta * Grad L^train(theta^{T}) )
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets)  # L^val
            meta_loss.backward()
            # grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters())  # dmeta_loss/dw0
        print('\n---- Outer loop print statements ----')
        print(f'----> outer_i = {outer_i}')
        print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
        # print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
        print(f'hidden.grad = {hidden.grad}')
        assert hidden.grad is not None
        assert eta.fc.weight is not None
        print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
        print(f'> hidden^<{outer_i-1}> = {hidden}')  # before update
        print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        meta_opt.step()  # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta * Grad L^train(theta^{T}) )
        print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}')  # after update
        print(f'>> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        print()

if __name__ == "__main__":
    main()
    print('---> Done\a')
I got lucky. It breaks if you comment out

trainable_opt_state = {'prev_lr':hidden}

inside here:
# do inner-training: ~ argmin L^train(theta)
nb_inner_steps = 2
trainable_opt_params = {'eta':eta, 'hidden':hidden}
trainable_opt_state = {'prev_lr':hidden}
print(trainable_opt_state)
child_model_params = [{'params':child_model.parameters()}]
inner_opt = MySGD(child_model_params, trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
print('==== Inner Loop ====')
I believe it's because, the next time I compute meta_loss, it holds a reference to the previous computation graph, since prev_lr then belongs to the previous iteration of the outer loop.
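If that diagnosis is right, the reset makes sense: pointing prev_lr back at the leaf tensor hidden at the top of each outer step means nothing from the previous iteration's graph stays reachable. Here is a self-contained sketch of the failure mode (simplified and hypothetical, not the actual higher code):

import torch
import torch.nn as nn

eta = nn.Linear(1, 1, bias=False)
hidden = torch.randn(1, 1, requires_grad=True)  # a leaf: carries no graph history

prev_lr = hidden
for outer_i in range(2):
    prev_lr = hidden          # the reset; comment this out to reproduce the error
    lr = 0.01 * eta(prev_lr)  # builds a new graph on top of prev_lr
    loss = lr.sum()
    loss.backward()           # frees this iteration's intermediary buffers
    prev_lr = lr              # the carried value still references this graph

# Without the reset, the second backward() chains into the first iteration's
# already-freed graph and raises "Trying to backward through the graph a second time".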