Hi, I’m experimenting with a loss of this kind:
$$
\mathcal{L}_{total} = \mathcal{L} + \alpha \cdot \frac{1}{N} \sum_{i=1}^{N} \left|\mathcal{L} - \mathcal{L}_{\epsilon_i}\right|
$$
where $\mathcal{L}_{\epsilon_i}$ is the loss of the model with its weights perturbed by a vector $\epsilon_i$. The tricky part, of course, is the second term of the loss function.
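For reference, here is a sketch of the gradient of one penalty term. It assumes each $\epsilon_i$ is sampled independently of the weights $\theta$, so it is treated as a constant, and writes $\mathcal{L} = \mathcal{L}(\theta)$ and $\mathcal{L}_{\epsilon_i} = \mathcal{L}(\theta + \epsilon_i)$:

$$
\nabla_\theta \left|\mathcal{L}(\theta) - \mathcal{L}(\theta + \epsilon_i)\right| = \operatorname{sign}\left(\mathcal{L}(\theta) - \mathcal{L}(\theta + \epsilon_i)\right) \left(\nabla_\theta \mathcal{L}(\theta) - \nabla \mathcal{L}(\theta + \epsilon_i)\right)
$$

Here $\nabla \mathcal{L}(\theta + \epsilon_i)$ is the gradient of the loss evaluated at the perturbed weights. By the chain rule, $\partial(\theta + \epsilon_i)/\partial\theta$ is the identity. The absolute value is not differentiable where the difference is exactly zero. Backpropagating this term therefore needs the forward graph of the unperturbed pass and of every perturbed pass, each built from its own copy of the weights.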
How do I implement this in PyTorch? With GPT’s help I ended up with several versions of this:
```python
import torch


def sum_of_loss_differences_optimized(model, criterion, inputs, target, num_perturbations=5, perturbation_strength=1e-4):
    # Compute original loss
    original_output = model(inputs)
    original_loss = criterion(original_output, target)
    loss_sum = 0.0
    # Copy parameters
    original_params = [param.clone() for param in model.parameters()]
    for i in range(num_perturbations):
        print(f"Perturbation No. {i}")
        perturbations = []
        with torch.no_grad():
            # Perturb every parameter in place by uniform noise in [-strength, strength)
            for param in model.parameters():
                perturbation = (torch.rand_like(param) * 2 - 1) * perturbation_strength
                param.add_(perturbation)
                perturbations.append(perturbation)
        # Compute perturbed loss
        perturbed_output = model(inputs)
        perturbed_loss = criterion(perturbed_output, target)
        loss_sum += torch.abs(original_loss - perturbed_loss)
        with torch.no_grad():  # Recover original parameters
            for param, original_param in zip(model.parameters(), original_params):
                param.copy_(original_param)
    return loss_sum / num_perturbations
```
However, when I compute the overall loss like this:
```python
loss = criterion(outputs, labels) + 0.1 * sum_of_loss_differences_optimized(model, criterion, inputs, labels, num_perturbations=5, perturbation_strength=1e-4)  # Compute the loss
```
I end up with a RuntimeError during backpropagation:
```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 10]], which is output 0 of AsStridedBackward0, is at version 11; expected version 10 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```
This line of code shows up in the detailed error message:
```python
perturbed_output = model(inputs)
```
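A minimal sketch that appears to reproduce the same kind of error without a model, by applying the perturb/restore pattern of `sum_of_loss_differences_optimized` to a single tensor. `Tensor._version` is an internal PyTorch attribute, read here only to make the version counter visible; the tensor `w` is a stand-in for a model weight.

```python
import torch

# A leaf tensor standing in for a model weight.
w = torch.randn(3, requires_grad=True)
w_saved = w.detach().clone()

# w * w saves w for the backward pass, together with its current version.
y = (w * w).sum()

version_before = w._version
with torch.no_grad():
    w.add_(1e-4)      # perturb in place: the version counter goes up
    w.copy_(w_saved)  # restore the values: the version counter goes up again
version_after = w._version

# The values are back to the original, but the version no longer matches
# the one recorded when w was saved, so autograd refuses to use it.
try:
    y.backward()
    error_message = None
except RuntimeError as err:
    error_message = str(err)

print(version_before, version_after)
print(error_message)
```

Restoring the values with `copy_` does not help, because autograd checks the version counter rather than the contents. In the function above, each forward pass saves the current weights for backward, and each later `add_` or `copy_` on those same parameter tensors invalidates what was saved.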
How do I get this to run? Mathematically it should be possible to implement this.
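One possible way around this, sketched under assumptions: instead of modifying `model.parameters()` in place, build the perturbed weights out-of-place and run the model with them through `torch.func.functional_call` (available in PyTorch 2.0 and later). The function name `sum_of_loss_differences_functional`, its extra `original_loss` argument, and the toy `nn.Linear(128, 10)` model with random data are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


def sum_of_loss_differences_functional(model, criterion, inputs, target, original_loss,
                                       num_perturbations=5, perturbation_strength=1e-4):
    # The model's own parameters; gradients should flow back into these.
    params = dict(model.named_parameters())
    loss_sum = 0.0
    for _ in range(num_perturbations):
        # Perturbed weights are new tensors (p + noise) that still depend on p,
        # so nothing saved for backward is ever modified in place.
        perturbed_params = {
            name: p + (torch.rand_like(p) * 2 - 1) * perturbation_strength
            for name, p in params.items()
        }
        # Forward pass with the perturbed weights; the model's parameters are untouched.
        perturbed_output = functional_call(model, perturbed_params, (inputs,))
        perturbed_loss = criterion(perturbed_output, target)
        loss_sum = loss_sum + torch.abs(original_loss - perturbed_loss)
    return loss_sum / num_perturbations


# Usage with a toy model and random data (hypothetical shapes).
torch.manual_seed(0)
model = nn.Linear(128, 10)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(32, 128)
labels = torch.randint(0, 10, (32,))

outputs = model(inputs)
base_loss = criterion(outputs, labels)
loss = base_loss + 0.1 * sum_of_loss_differences_functional(
    model, criterion, inputs, labels, base_loss,
    num_perturbations=5, perturbation_strength=1e-4,
)
loss.backward()  # no in-place modification of saved tensors, so this should not raise
```

Passing `original_loss` in reuses the loss already computed for the first term instead of running the unperturbed forward pass twice. Because `p + noise` keeps a graph edge back to `p`, `loss.backward()` propagates the penalty's gradient through the unperturbed pass and every perturbed pass. Each perturbed graph stays in memory until `backward()`, so memory use grows with `num_perturbations`.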