Removing spectral norm changes the output values

In my current work I need to use spectral norm (SN), and I want to export my model to ONNX. Unfortunately, SN interferes with the ONNX export.
My plan was therefore to remove SN, export to ONNX, and put SN back on. While trying this, I realized that removing the spectral norm and putting it back in changes the weights, and with them the output of the layer.
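The workflow I had in mind is roughly the following (sketched on a single conv layer rather than my real model; the file name is just a placeholder):

import torch

conv = torch.nn.utils.spectral_norm(torch.nn.Conv2d(3, 1, 3))
# ... training with SN in place would happen here ...

torch.nn.utils.remove_spectral_norm(conv)                       # strip SN so the export works
torch.onnx.export(conv, torch.randn(1, 3, 4, 4), "conv.onnx")   # export the plain conv
torch.nn.utils.spectral_norm(conv)                              # put SN back on afterwards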

Let me illustrate the weight issue with the following sample code:

import torch


t0 = torch.randn(2000, 3, 4, 4)

# c1: plain conv layer, no SN
c1 = torch.nn.Conv2d(3, 1, 3)
w1 = c1.weight.data.clone()
o1 = c1(t0)

# c2: the same layer with SN applied (weight cloned before the forward pass)
c2 = torch.nn.utils.spectral_norm(c1)
w2 = c2.weight.data.clone()
o2 = c2(t0)

# c3: SN removed again
c3 = torch.nn.utils.remove_spectral_norm(c2)
w3 = c3.weight.data.clone()
o3 = c3(t0)

# c4: SN re-applied
c4 = torch.nn.utils.spectral_norm(c3)
w4 = c4.weight.data.clone()
o4 = c4(t0)

print(f"(o1==o2).float().mean(): {torch.eq(o1,o2).float().mean()}")
print(f"(o2==o3).float().mean(): {torch.eq(o2,o3).float().mean()}")
print(f"(o3==o4).float().mean(): {torch.eq(o3,o4).float().mean()}")
print(f"(o1==o3).float().mean(): {torch.eq(o1,o3).float().mean()}")
print(f"(o2==o4).float().mean(): {torch.eq(o2,o4).float().mean()}")
print()
print(f"(w1==w2).float().mean(): {torch.eq(w1,w2).float().mean()}")
print(f"(w2==w3).float().mean(): {torch.eq(w2,w3).float().mean()}")
print(f"(w3==w4).float().mean(): {torch.eq(w3,w4).float().mean()}")
print(f"(w1==w3).float().mean(): {torch.eq(w1,w3).float().mean()}")
print(f"(w2==w4).float().mean(): {torch.eq(w2,w4).float().mean()}")

Essentially:

  • c1, c2, c3, and c4 are the same conv layer, with SN applied and removed several times
  • o1, o2, o3, and o4 are the outputs of these layers for the same input
  • w1, w2, w3, and w4 are the weights of the layers, cloned right after each apply/remove step (before the corresponding forward pass)

The output of the above script is typically as follows:

(o1==o2).float().mean(): 0.0
(o2==o3).float().mean(): 1.0
(o3==o4).float().mean(): 0.0533749982714653
(o1==o3).float().mean(): 0.0
(o2==o4).float().mean(): 0.0533749982714653

(w1==w2).float().mean(): 1.0
(w2==w3).float().mean(): 0.0
(w3==w4).float().mean(): 1.0
(w1==w3).float().mean(): 0.0
(w2==w4).float().mean(): 0.0

My understanding of SN is not very deep and I do not fully understand the math behind it, but my rough picture is this:

  • Spectral norm enforces a constraint on the weights of the layer it is applied to (as far as I understand, it keeps the spectral norm of the weight matrix at 1), which forbids some possible weight values
  • This constraint is at least partially implemented in PyTorch by decomposing the weight of the affected layer into weight, weight_u, and weight_v somehow (see the quick inspection right after this list)
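A quick way to check this is to inspect what spectral_norm adds to the module. Assuming the internal names weight_orig, weight_u, and weight_v used by current PyTorch (this is my reading of the implementation, the names could differ between versions), something like this should show them:

import torch

c = torch.nn.utils.spectral_norm(torch.nn.Conv2d(3, 1, 3))

# The original weight seems to be renamed to weight_orig and stays a Parameter,
# while weight_u / weight_v are registered as buffers for the power iteration.
print([name for name, _ in c.named_parameters()])  # expecting ['bias', 'weight_orig']
print([name for name, _ in c.named_buffers()])     # expecting ['weight_u', 'weight_v']

# weight itself appears to be recomputed on every forward call by a hook,
# so it is a plain tensor rather than a Parameter.
print(type(c.weight))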

Given that, here is what I think of the results:

  1. o1!=o2: this seems logical, since c1 was not subject to any restriction, so it is to be expected that applying SN can change the layer's behavior
  2. o2==o3: this seems logical, since lifting the SN restriction should not modify the functionality of the layer
  3. o3!=o4: this is unexpected, since c2 should already have satisfied the SN restriction, and c3 represents a function that is equivalent to c2… c3 should therefore also satisfy the SN restriction without modification… or so I thought. Interestingly, running the same script several times sometimes gives me o3==o4, which would fit my expectations… but not always.
  4. o1!=o3: this follows from 1
  5. o2!=o4: this is what prompted me to investigate, and the issue seems to come from 3
  6. w1==w2: not exactly what I expected, but understandable. It looks like the SN restriction is enforced by doing extra work during the convolution using weight_u and weight_v, while weight itself is kept as it was (see the small check right after this list)
  7. w2!=w3: logical, since we need o2==o3, the effect of weight_u and weight_v must be folded back into weight at this step
  8. w3==w4: logical, this is just a repeat of 6
  9. w1!=w3: this follows from 7
  10. w2!=w4: not exactly what I expected, but understandable given 7
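If my reading of points 6 and 7 is correct, then after a forward pass the visible weight should simply be weight_orig divided by a singular value estimate built from weight_u and weight_v. Here is a small check along those lines; the attribute names and the way sigma is computed are my guesses about the internals, not something I found documented:

import torch

c = torch.nn.utils.spectral_norm(torch.nn.Conv2d(3, 1, 3))
x = torch.randn(8, 3, 4, 4)
c(x)  # one forward pass so the hook recomputes c.weight

# Guess: sigma is estimated as u^T W v, with W flattened to (out_channels, -1),
# and the visible weight is weight_orig / sigma.
W = c.weight_orig.detach().reshape(c.weight_orig.size(0), -1)
sigma = torch.dot(c.weight_u, torch.mv(W, c.weight_v))
print(torch.allclose(c.weight, c.weight_orig / sigma))  # True if the guess holds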

Is my understanding correct?
What method can I use to temporarily remove SN from my model, and then put it back, without it messing up the already learned weights?