Hi!
I’m trying to avoid retain_graph=true, but I can’t figure out where I’m doing an extra backward where I shouldn’t. Could someone please give me some guidance? Below is my code:
I’m porting a tensorflow meta-learning algorithm to PyTorch. It’s a variant of MAML so, there is an ‘inner loop’ of the algorithm where we do two backwards. This inner loop is handled by autograd.grad, and the outer loop wrapping is has a loss.backward(). From others’ implementation of MAML as well as many posts here about multiple backwards, I know that the outer loop should not need a retain_graph=true to work. But runtime error says I should, and I struggle to figure out where I got it wrong. Here’s my model: (what makes it different from MAML is an ‘inference network,’ which I implemented as a bunch of extra nn.module (it’s the self.encoder field in __init())
code I’m porting: GitHub - haebeom-lee/l2b: Tensorflow implementation of "Learning to Balance: Bayesian Meta-learning for Imbalanced and Out-of-distribution Tasks" (ICLR 2020 oral)
import torch
import numpy as np
from layers import *
from encoder import InferenceNetwork
from torch import autograd
NUMBER_OF_HIDDEN_CHANNELS = 32
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class LearningToBalance():
def __init__(self, data_name : str, number_of_inner_gradient_steps : int,
ways : int, shots : int, inner_learning_rate : float,
outer_learning_rate : float, batch_size : int):
if data_name == 'cifar':
self.xdim = 32
self.number_of_channels = 32
elif data_name == 'mimgnet':
self.xdim = 84
self.number_of_channels = 32
else:
raise ValueError("Invalid dataset name")
# training parameters
self.number_of_inner_gradient_steps = number_of_inner_gradient_steps
self.number_of_classes = ways
# self.metabatch = metabatch
self.inner_lr = inner_learning_rate
self.outer_lr = outer_learning_rate
self.batch_size = batch_size # number of TASKs in a batch!
# balancing variables from InferenceNetwork
self.encoder = InferenceNetwork(ways, shots, data_name, True, True, True).to(DEVICE)
self.use_inner_step_size_vector = True # learnable, from Meta-SGD
self.use_zeta = True
self.use_gamma = True
self.use_omega = True
self.use_alpha = True # learn alpha, instead of using set value
"""model definition: 4 convolutional layers, followed, by 1 linear layer"""
self.number_of_convolutional_layer = 4
self.input_channel = 3
self.kernel_size = 3
self.number_of_hidden_channels = NUMBER_OF_HIDDEN_CHANNELS
self.parameter = self.get_parameter_maml('theta') # 'this is the 'theta'
self.alpha = self.get_parameter_maml('alpha')
"""training"""
self.optimizer = torch.optim.Adam(params=list(self.parameter.values()) + list(self.alpha.values()), lr=self.outer_lr,)
self.validation_interval = 50
def get_parameter_maml(self, name : str) -> dict[str, torch.Tensor]:
"""initializes self.parameters"""
# initializers
if name == 'theta':
init_convolution = lambda x : torch.nn.init.xavier_uniform_(x)
init_bias = lambda x : torch.nn.init.zeros_(x)
init_dense = lambda x : torch.nn.init.xavier_uniform_(x)
else: # name == 'alpha'
init_convolution = lambda x : torch.nn.init.constant_(x, 0.01)
init_bias = lambda x : torch.nn.init.constant_(x, 0.01)
init_dense = lambda x : torch.nn.init.constant_(x, 0.01)
parameters = {}
input_channel = self.input_channel
# convolutional layers
print(f"model, get_parameter, generating params {name}")
for l in range(1, self.number_of_convolutional_layer + 1):
parameters[f'convolution_{l}_weight'] = init_convolution(torch.empty(
size=[self.number_of_hidden_channels, input_channel, self.kernel_size, self.kernel_size],
requires_grad=True,
device=DEVICE
))
parameters[f'convolution_{l}_bias'] = init_bias(torch.empty(
# do NOT assign int to 'size' field! (type problem)
# for weight dimension [64, 3, 3, 3], we expect bias of [64]<-one dimension only!
self.number_of_hidden_channels,
requires_grad=True,
device=DEVICE
))
input_channel = self.number_of_hidden_channels
# linear layer
parameters[f'dense_weight'] = init_dense(torch.empty(
size=[self.number_of_classes, self.number_of_hidden_channels],
requires_grad=True,
device=DEVICE
))
parameters[f'dense_bias'] = init_bias(torch.empty(
size=[self.number_of_classes],
requires_grad=True,
device=DEVICE
))
return parameters
def forward_theta(self, x : torch.Tensor, theta : dict[str, torch.Tensor], name : str):
x = torch.reshape(x, [-1, self.input_channel, self.xdim, self.xdim])
for l in range(1, self.number_of_convolutional_layer + 1):
if (x.isnan().any()):
x = F.conv2d(x, theta[f'convolution_{l}_weight'], theta[f'convolution_{l}_bias'], stride=1, padding='same')
if (x.isnan().any()):
x = F.batch_norm(x, None, None, training=True)
if (x.isnan().any()):
x = F.relu(x)
x = torch.mean(x, dim=[2, 3])
x = DenseBlock_F(x, theta[f'dense_weight'], theta[f'dense_bias']) # just executes the calculation
return x
def zeta_update_initialization(self,
zeta : dict[str, torch.Tensor], theta : dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""update theta with zeta"""
theta_update_by_zeta = {}
if self.use_zeta:
for layer_key in zeta.keys():
if '_w' in layer_key:
# θ_0 = θ + zeta
theta_update_by_zeta[layer_key] = theta[layer_key] * (1. + zeta[layer_key])
elif '_b' in layer_key:
theta_update_by_zeta[layer_key] = theta[layer_key] + zeta[layer_key]
else:
assert("LearningToBalance::theta_update_initialization(): \n\tcheck ur dictionary keys!")
return theta_update_by_zeta
def omega_modify_loss_of_class(self,
omega: torch.Tensor, cross_entropy_per_class : torch.Tensor) -> torch.Tensor:
if self.use_omega:
inner_loss = torch.sum(cross_entropy_per_class * F.softmax(omega, -1))
else:
inner_loss = torch.sum(cross_entropy_per_class)
return inner_loss
def gamma_modeify_inner_step_learning_rate(self,
theta_tensor : torch.Tensor, gamma_tensor : torch.Tensor, delta : torch.Tensor) -> torch.Tensor:
"""change learning rate with gamma"""
if self.use_gamma:
return theta_tensor - delta * torch.exp(gamma_tensor)
else:
return theta_tensor - delta
def _inner_loop(self, x_train : torch.Tensor, y_train : torch.Tensor, train : bool,
omega : torch.Tensor, gamma : dict[str, torch.Tensor]) -> tuple[dict[str, torch.Tensor], list[torch.Tensor]]:
"""return: updated parameters theta, list of inner accuracy (length is number_of_inner_gradient_steps + 1)"""
accuracy_list = []
y_train = y_train.to(torch.int64)
param_cloned = {
k: torch.clone(v)
for k, v in self.parameter.items()
}
for _ in range(self.number_of_inner_gradient_steps):
if (x_train.isnan().any()):
print(f"forward batch_norm: nan in put")
inner_logits = self.forward_theta(x_train, param_cloned, name="inner_loop") # predictions
cross_entropy_per_class = F.cross_entropy(inner_logits, y_train)
inner_accuracy = Accuracy(inner_logits, y_train)
"""omega: modulate class-specific loss"""
inner_loss = self.omega_modify_loss_of_class(omega, cross_entropy_per_class)
# make computation graph, so 2nd-order derivativa can be calculated
# when we call .backward() on OUTER loss
# when train, DO make graph, to allow back propagation; when validate, NO make graph!
grads = autograd.grad(
outputs=inner_loss, inputs=param_cloned.values(), create_graph=train)
gradient_dictionary : dict[str, torch.Tensor] = dict(zip(param_cloned.keys(), grads))
accuracy_list.append(inner_accuracy)
"""inner gradient step"""
for layer_key in param_cloned.keys():
if self.use_alpha: # manually checked that alpha is bound
delta = self.alpha[layer_key] * gradient_dictionary[layer_key]
else:
delta = self.inner_lr * gradient_dictionary[layer_key]
"""use gamma: modulate task-specific learning rate"""
param_cloned[layer_key] = self.gamma_modeify_inner_step_learning_rate(
theta_tensor=param_cloned[layer_key], gamma_tensor=gamma[layer_key], delta=delta
)
# no gradient update for this last logit and accuracy
last_inner_logits = self.forward_theta(x_train, param_cloned, name="inner_loop last")
last_inner_accuracy = Accuracy(last_inner_logits, y_train)
accuracy_list.append(last_inner_accuracy.item())
return param_cloned, accuracy_list
def _outer_step_single_task(self, input_task : tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
train : bool) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]:
"""return: cross entropy, outer accuracy, KL, prediction, inner accuracy
\nouter objective w.r.t. one task"""
omega, gamma, zeta, KL = self.encoder(x_train, y_train, do_sample=train)
Cardinality_train = x_train.shape[0]
Cardinality_test = x_test.shape[0]
# scale KL term w.r.t train and test set sizes
KL /= (Cardinality_train + Cardinality_test)
"""zeta: modulate MAML initialization"""
theta_update_by_zeta = self.zeta_update_initialization(zeta, self.parameter)
self.parameter.update(theta_update_by_zeta)
"""inner gradient steps; omega & gamma for task_specific modulation"""
theta_update_by_inner_loop, inner_accuracy = self._inner_loop(x_train, y_train, train, omega, gamma)
# self.parameter.update(theta_update_by_inner_loop)
"""outer-loss & test_accuracy"""
logits_test = self.forward_theta(x_test, theta_update_by_inner_loop, name="outer loop")
cross_entropy = CrossEntropy(logits_test, y_test.to(torch.int64))
outer_accuracy = Accuracy(logits_test, y_test)
prediction = F.softmax(logits_test, -1)
return cross_entropy, outer_accuracy, KL, prediction, inner_accuracy
def _outer_step_(self, input_task_batch : list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], train : bool
):
"""return: loss, result dictionary
\nmade of _outer_step_single_task"""
# print(f"model _outer_step")
cross_entropy_list = []
outer_accuracy_list : list[torch.Tensor]= []
KL_list = []
prediction_list = []
inner_accuracy_list : list[list[torch.Tensor]]= []
for t in range(self.batch_size):
input_task = (input_task_batch[0][t], input_task_batch[1][t],
input_task_batch[2][t], input_task_batch[3][t])
cross_entropy, outer_accuracy, KL, prediction, inner_accuracy = self._outer_step_single_task(input_task, train)
cross_entropy_list.append(cross_entropy)
outer_accuracy_list.append(outer_accuracy)
inner_accuracy_list.append(inner_accuracy)
prediction_list.append(prediction)
KL_list.append(KL)
loss = torch.mean(torch.stack(cross_entropy_list))
KL = torch.mean(torch.stack(KL_list))
return loss, torch.mean(torch.stack(outer_accuracy_list)), KL, torch.mean(torch.Tensor(inner_accuracy_list), dim=0)
def train(self, train_dataloader, valid_dataloader, ):
"""return: train loss, valid loss, train accuracy dictionary, valid accuracy dicitonary"""
with autograd.detect_anomaly():
print(f"T-R-A-I-N")
print(f"train_dataloader length? {len(train_dataloader)}")
print(f"valid_dataloader length? {len(valid_dataloader)}")
# print(f"length of dataloader {len(train_dataloader)}")
train_loss = []
train_outer_accuracy = [] # outer means query; inner means support
train_inner_accuracy_fresh = [] # pre-adapting
train_inner_accuracy_adapt = [] # post-adapting
"""TODO: show valid loss s.t. it reflects when number of validation != number of training, or throw Error"""
valid_loss = []
valid_outer_accuracy = [] # outer means query; inner means support
valid_inner_accuracy_fresh = [] # pre-adapting
valid_inner_accuracy_adapt = [] # post-adapting
# print(f"yo")
for step, task_batch in enumerate(train_dataloader):
"""task_batch is a BATCH (list) of tasks, batch = batch_size,
shape: [task batch, number of classes, width, height, number of channels]
e.g. [5, 70 = way * (shot + query), 32, 32, 3 bc RGB]"""
self.optimizer.zero_grad()
outer_loss, outer_accuracy, KL, inner_accuracy = self._outer_step_(task_batch, train=True)
outer_loss.backward()
self.optimizer.step()
train_loss.append(outer_loss.item())
train_outer_accuracy.append(outer_accuracy)
train_inner_accuracy_fresh.append(inner_accuracy[0])
train_inner_accuracy_adapt.append(inner_accuracy[-1])
if step % self.validation_interval == 0: # it's validation time!
print(f"\tV-A-L-I-D")
for v_step, task_batch in enumerate(valid_dataloader):
# print(f"valid step {v_step}")
outer_loss, outer_accuracy, KL, inner_accuracy = self._outer_step_(task_batch, train=False)
valid_loss.append(outer_loss.item())
valid_outer_accuracy.append(outer_accuracy)
valid_inner_accuracy_fresh.append(inner_accuracy[0])
valid_inner_accuracy_adapt.append(inner_accuracy[-1])
train_accuracy = {
'outer': train_outer_accuracy,
'inner pre-adapt': train_inner_accuracy_fresh,
'inner post_adapt': train_inner_accuracy_adapt
}
valid_accuracy = {
'outer': valid_outer_accuracy,
'inner pre-adapt': valid_inner_accuracy_fresh,
'inner post_adapt': valid_inner_accuracy_adapt
}
return train_loss, valid_loss, train_accuracy, valid_accuracy
def test(self, test_dataloader):
"""return: test loss, test accuracy dictionary"""
print("T-E-S-T")
print(f"test_dataloader length? {len(test_dataloader)}")
test_loss = []
test_outer_accuracy = [] # outer means query; inner means support
test_inner_accuracy_fresh = [] # pre-adapting
test_inner_accuracy_adapt = [] # post-adapting
for step, task_batch in enumerate(test_dataloader):
# print(f"test step {step}")
outer_loss, outer_accuracy, KL, inner_accuracy = self._outer_step_(task_batch, train=False)
test_loss.append(outer_loss.item())
test_outer_accuracy.append(outer_accuracy)
test_inner_accuracy_fresh.append(inner_accuracy[0])
test_inner_accuracy_adapt.append(inner_accuracy[-1])
test_accuracy = {
'outer': test_outer_accuracy,
'inner pre-adapt': test_inner_accuracy_fresh,
'inner post_adapt': test_inner_accuracy_adapt
}
return test_loss, test_accuracy
Parameters are stored in a dictionary (4 convolution layers followed by one linear layer). They are updated manually. The ‘encoder’ network jumps in from time to time to calculate some variables (gamma, theta, omega) to join the calculation. My suspicion is that something’s up with the inference network (it basically generates summary statistics from the data–I’m confused by the tensorflow code so I left it as a learnable nn module). I’m missing something like autograd.detect_double_backward() (made-up function) that tells me exactly where I’m backwarding-when I shouldn’t, or some things to keep in mind when reading my code for bugs.
Thanks so much!