Build a new Loss Function with Infinity Norm

Hi all! I am trying to write a new loss function that uses the infinity norm of the weights.
I am trying something along these lines:

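import torch

def l2_regu(mdl):
    # Sum of squared max-abs values (squared infinity norms) over the weight tensors.
    l2_reg = None
    for W in mdl.parameters():
        # Skip biases and other 1-D parameters.
        if W.ndimension() < 2:
            continue
        else:
            w_tmp = W
            if l2_reg is None:
                l2_reg = (torch.max(torch.abs(w_tmp)))**2
            else:
                l2_reg = l2_reg + (torch.max(torch.abs(w_tmp)))**2

    return l2_reg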
Running this throws the following error:
File "train.py", line 208, in train
    oloss =  l2_regu(model)
  File "train.py", line 38, in l2_reg_ortho
    l2_reg = l2_reg + (torch.max(torch.abs(w_tmp)))**2
RuntimeError: cuda runtime error (2) : out of memory at /pytorch/torch/lib/THC/generic/THCStorage.cu:58
(torch3) bansa01@vita2:~/pytorch_wideres/tmp_inf/WideResNet-pytorch$

Everything works fine if I replace this with an L2 loss, but with the infinity norm it throws the error above.

Regards,
Nitin

Hi,

The error says that you ran out of GPU memory.
My guess is that your script already uses almost all of the memory when it runs with the L2 norm.
Your implementation creates intermediate results (the output of torch.abs, then of torch.max), so it uses a bit more memory.
Reducing your batch size should reduce the memory footprint and let it run.

Thanks @albanD!

The issue is a bit weird. It actually works fine, without any issue, if I simply replace the infinity norm with the L2 norm. I also reduced the batch size to as low as 4, but the same error persists. Do you think this could be due to some other reason?

Hi,
Your code does not call an “infinity” norm function; it computes the norm as torch.max(torch.abs(w_tmp))**2. If the L2 version uses w_tmp.norm(2), then that version performs one operation while this one performs three (abs, max, and the square).
What is the memory usage of the corresponding code that uses the L2 regularization?
If you could provide a small code snippet that reproduces the issue, that would be helpful.
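A small reproduction of the kind asked for here might look like the sketch below. It is an illustration, not the original train.py: the tiny model, the hand-set weights, and the helper names linf_reg, l2_reg and allocated_mib are all hypothetical. It computes both regularizers on the same parameters, runs backward on each, and reports torch.cuda.memory_allocated() when a GPU is present.

```python
import torch
import torch.nn as nn


def linf_reg(model):
    """Sum of squared max-abs values over all parameters with >= 2 dims."""
    terms = [p.abs().max() ** 2 for p in model.parameters() if p.dim() >= 2]
    return torch.stack(terms).sum()


def l2_reg(model):
    """Sum of squared L2 norms over the same parameters, for comparison."""
    terms = [p.norm(2) ** 2 for p in model.parameters() if p.dim() >= 2]
    return torch.stack(terms).sum()


def allocated_mib():
    """MiB held by live tensors on the current CUDA device (0.0 without CUDA)."""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated() / 2 ** 20
    return 0.0


# Hypothetical stand-in for the real network, with weights set by hand
# so the expected regularizer values are easy to verify.
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
with torch.no_grad():
    model[0].weight.copy_(torch.tensor([[0.5, -2.0], [1.5, 1.0]]))
    model[2].weight.copy_(torch.tensor([[3.0, -1.0]]))

for name, reg_fn in [("linf", linf_reg), ("l2", l2_reg)]:
    model.zero_grad()
    reg = reg_fn(model)
    reg.backward()  # both regularizers are differentiable
    print(name, reg.item(), "allocated MiB:", allocated_mib())
```

The model here lives on the CPU, so the memory figure only becomes informative after moving the model to the GPU. Printing it once per training iteration would show whether memory grows step by step or jumps at a single call.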

Hello Alan, somehow the infinity norm always returns 1, which is an incorrect value. My code is the same as I posted earlier:

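import torch

def l2_regu(mdl):
    # Sum of squared max-abs values (squared infinity norms) over the weight tensors.
    l2_reg = None
    for W in mdl.parameters():
        # Skip biases and other 1-D parameters.
        if W.ndimension() < 2:
            continue
        else:
            w_tmp = W
            if l2_reg is None:
                l2_reg = (torch.max(torch.abs(w_tmp)))**2
            else:
                l2_reg = l2_reg + (torch.max(torch.abs(w_tmp)))**2

    return l2_reg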

Even though I kept my batch size at 4 (very small), I saw through nvidia-smi that within one epoch the memory usage went from 100 MB to 12 GB, which triggers the CUDA out-of-memory error. If I simply replace (torch.max(torch.abs(w_tmp)))**2 with the L2 norm, everything works perfectly. I understand that there are three operations here in contrast to one for the L2 norm, but I am getting the memory error within a single epoch, and with a very small batch size.

Let’s experiment a bit…
Given a = torch.randn(5, 5),
a.norm(p=np.inf) always returns 1.0.
However, torch.max(torch.abs(a)) correctly returns the absolute value of the element with the largest magnitude.

So either my understanding of the infinity norm is wrong, or torch.norm(p=np.inf) is doing something strange.
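For reference, the infinity norm of a tensor treated as a flat vector is the largest absolute value among its elements, so torch.max(torch.abs(a)) matches the definition. The sketch below checks this on a hand-picked tensor. It assumes a recent PyTorch release in which torch.linalg.vector_norm is available, which is not the version discussed in this thread.

```python
import torch

# Hand-picked values: the element with the largest magnitude is -2.0.
a = torch.tensor([[0.5, -2.0], [1.5, 1.0]])

# Infinity norm straight from the definition: max over |a_i|.
max_abs = a.abs().max()

# Same quantity through the linalg API (assumes a recent PyTorch);
# with dim=None the tensor is flattened before the norm is taken.
via_linalg = torch.linalg.vector_norm(a, ord=float("inf"))

print(max_abs.item(), via_linalg.item())
```

If the two printed values agree, the definition used in the regularizer is sound, and a constant 1.0 from the older norm call points at that call rather than at the math.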

Moreover, as noted in another thread, torch.norm(p=np.inf) throws an error when applied to a Variable.

@jpeg729 Yeah. So if we apply norm to a tensor it returns 1, and if we apply it to a Variable it returns an error. When I replace it with torch.max(torch.abs(a)), it returns the correct value if run standalone, but when I embed it in my training code it throws the CUDA out-of-memory error (confirmed by nvidia-smi, which shows the memory usage exploding), whereas the same function with the L2 norm works fine.

So my conclusion is that torch.max(torch.abs(a)) is somehow blowing things up, but I do not understand why or how to fix it. My aim is to add a regularizer based on the L-infinity norm of the weights.

My guess is that reading the source code for torch.norm might be worthwhile.