Dear all,
I have a problem with the following optimization procedure. I want to optimize the input tensor `x` to reduce the L2 norm of the activations. To do this, I register a forward hook on every layer. The hook updates a `Meter` object by appending the L2 norm of that layer's output to its `self.layers` list. After forwarding `x` through the model, I sum these norms to obtain the final loss.
```python
import torch
import torchvision.models as models


class Meter:
    def __init__(self):
        self.layers = []
        self.size = 0

    def register_stats(self, output):
        # Store the L2 norm of this layer's output
        self.layers.append(output.norm(2))
        self.size += 1


def get_loss(model, x):
    # Leaf modules: modules without children (Conv2d, ReLU, MaxPool2d, Linear, ...)
    leaf_nodes = [module for module in model.modules()
                  if len(list(module.children())) == 0]
    stats = Meter()

    def _get_activation():
        def hook_fn(model, input, output):
            stats.register_stats(output)
        return hook_fn

    hooks = register_hooks(leaf_nodes, _get_activation)
    model(x)
    loss = 0
    for i in range(stats.size):
        loss += stats.layers[i]
    remove_hooks(hooks)
    return loss


def register_hooks(leaf_nodes, hook):
    hooks = []
    for i, node in enumerate(leaf_nodes):
        hooks.append(node.register_forward_hook(hook()))
    return hooks


def remove_hooks(hooks):
    for hook in hooks:
        hook.remove()


if __name__ == '__main__':
    x = torch.rand(1, 3, 224, 224).cuda()
    x.requires_grad_()
    patch_optimizer = torch.optim.SGD([x], lr=0.01, momentum=0.9, weight_decay=0)
    model = models.vgg16(pretrained=True).cuda()
    for param in model.parameters():
        param.requires_grad = True
    model.eval()
    for step in range(10):
        patch_optimizer.zero_grad()
        loss = get_loss(model, x)
        loss.backward()
        patch_optimizer.step()
        print(f'Loss {loss:10.5f}')
```
However, I get the following error:
```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 4096]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```
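The traceback is consistent with an in-place operation overwriting a tensor that `norm` saved for the backward pass. torchvision's VGG16 uses `nn.ReLU(inplace=True)`, so every `Conv2d`/`Linear` output whose norm a hook has just recorded is then overwritten by the following ReLU, which bumps its version counter. A minimal sketch that appears to reproduce the same failure, assuming that is the cause (the two-layer `nn.Sequential` below is a hypothetical stand-in for VGG16's classifier):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for VGG16's classifier: Linear followed by an in-place ReLU
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(inplace=True)).eval()

norms = []

def hook_fn(module, inputs, output):
    # norm() saves `output` for backward at its current version
    norms.append(output.norm(2))

handles = [m.register_forward_hook(hook_fn) for m in model]

x = torch.rand(1, 8, requires_grad=True)
model(x)  # the in-place ReLU overwrites the Linear output, bumping its version
loss = sum(norms)

error_message = ""
try:
    loss.backward()
except RuntimeError as err:
    error_message = str(err)

for h in handles:
    h.remove()

print(error_message)
```

The printed message should be the same "modified by an inplace operation ... output 0 of ReluBackward0" error, just with a `[1, 4]` CPU tensor.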
Surprisingly, I was able to fix this problem with a “bad” solution, obtained with the following changes:
```python
...
class Meter:
    def __init__(self):
        self.layers = []
        self.size = 0

    def register_stats(self, output):
        self.layers.append(output)
        self.size += 1
...
def get_loss(model, x):
    ...
    model(x)
    loss = 0
    for i in range(stats.size):
        loss += stats.layers[i].norm(2)
    ...
```
In other words, I store the layer outputs themselves in the list and compute their norms only at the end.
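A likely reason this variant runs: the list holds references to the output tensors, not copies, so by the time `norm` is called the `Linear`/`Conv2d` outputs have already been overwritten by the in-place ReLU, and nothing modifies them after `norm` saves them. The side effect is that the loss changes: those layers contribute the norm of the ReLU-ed values rather than of their actual outputs. A small sketch illustrating this, using a hypothetical toy model with fixed weights so that the numbers are predictable:

```python
import torch
import torch.nn as nn

# Hypothetical toy model: a Linear layer with fixed weights, then an in-place ReLU
linear = nn.Linear(2, 2, bias=False)
with torch.no_grad():
    linear.weight.copy_(torch.tensor([[1.0, 0.0], [-1.0, 0.0]]))
model = nn.Sequential(linear, nn.ReLU(inplace=True)).eval()

outputs = []      # references to the outputs, as in the "bad" Meter
early_norms = []  # norms taken inside the hook, as in the original Meter

def hook_fn(module, inputs, output):
    outputs.append(output)
    # detach() keeps this diagnostic value out of the autograd graph
    early_norms.append(output.detach().norm(2).item())

handles = [m.register_forward_hook(hook_fn) for m in model]
x = torch.tensor([[0.5, 0.25]], requires_grad=True)
model(x)  # Linear output is [0.5, -0.5]; the ReLU overwrites it with [0.5, 0.0]
for h in handles:
    h.remove()

# Both list entries share the same storage, which now holds post-ReLU values
late_norms = [out.norm(2).item() for out in outputs]
print(outputs[0].data_ptr() == outputs[1].data_ptr())  # True
print(early_norms)  # Linear entry ~0.7071 (pre-ReLU), ReLU entry 0.5
print(late_norms)   # both entries 0.5

loss = sum(out.norm(2) for out in outputs)
loss.backward()  # runs: nothing is modified after norm() saves its input
```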
Can someone explain why the first version does not work? And is there a way to solve this problem without keeping every layer's output in memory?
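One possible workaround, sketched under the assumption that the in-place ReLU is the only culprit: switch every `nn.ReLU` to out-of-place (`module.inplace = False`) before registering the hooks, and accumulate a running sum in the hook instead of a list. Note that autograd still keeps the activations it needs for `backward()`, so holding references in a list was largely not an extra copy anyway; this mainly avoids the list. Another option is `output.clone().norm(2)` inside the hook, at the cost of one copy per hooked output. The `RunningLoss` class, `disable_inplace_relu` helper, and toy model below are hypothetical:

```python
import torch
import torch.nn as nn

class RunningLoss:
    """Hypothetical replacement for Meter: keeps only a running sum of norms."""
    def __init__(self):
        self.total = 0
        self.size = 0

    def register_stats(self, output):
        self.total = self.total + output.norm(2)
        self.size += 1

def disable_inplace_relu(model):
    # nn.ReLU reads self.inplace in forward(), so flipping the flag is enough
    for module in model.modules():
        if isinstance(module, nn.ReLU):
            module.inplace = False

def get_loss(model, x):
    leaf_nodes = [m for m in model.modules() if len(list(m.children())) == 0]
    stats = RunningLoss()
    hooks = [m.register_forward_hook(lambda mod, inp, out: stats.register_stats(out))
             for m in leaf_nodes]
    model(x)
    for h in hooks:
        h.remove()
    return stats.total

# Hypothetical toy model standing in for VGG16
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(inplace=True),
                      nn.Linear(4, 2), nn.ReLU(inplace=True)).eval()
disable_inplace_relu(model)
for param in model.parameters():
    param.requires_grad_(False)  # only x is optimized in this sketch

x = torch.rand(1, 8, requires_grad=True)
optimizer = torch.optim.SGD([x], lr=0.01, momentum=0.9)
losses = []
for step in range(3):
    optimizer.zero_grad()
    loss = get_loss(model, x)
    loss.backward()  # no in-place error: the saved outputs are never overwritten
    optimizer.step()
    losses.append(loss.item())
    print(f"Loss {loss.item():10.5f}")
```

With out-of-place ReLUs, each `Linear`/`Conv2d` hook sees its true pre-activation output, so the loss matches what the first version intended to compute (unlike the list-based variant).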
Thanks in advance
Torch version 1.10.1
Torchvision version 0.11.2