# How to correctly implement nuclear norm in pytorch 0.4.1

In short:

Although PyTorch 1.0 supports matrix norms, I still have to use PyTorch 0.4.1 because my CUDA version is too old. I implemented a nuclear-norm loss function, but it is very slow, probably because it contains two for-loops. How can I make it faster?

In detail:
My loss function works as follows.
`mat` is my output, a 4-dimensional tensor [batch, channel, width, height].
I want to compute the nuclear loss of each matrix, for every batch and channel.

With this loss, my network trains 10 times slower than before. I want to make it faster but don't know how.
Since `torch.mm` doesn't support broadcasting, I tried `torch.matmul`, but it failed.
I would be very grateful if someone could show me how to do this.
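For reference, `torch.matmul` does broadcast over leading batch dimensions when its inputs have more than two dimensions, so a 4-D tensor can be multiplied without reshaping. A minimal sketch (the sizes are arbitrary placeholders, and I can't tell from the post why the original `torch.matmul` attempt failed):

```python
import torch

torch.manual_seed(0)
batch, channel, Sx, Sy = 2, 3, 5, 4
mat = torch.randn(batch, channel, Sx, Sy)

# Swap the last two dims to get A^T for every (batch, channel) matrix,
# then let matmul treat the first two dims as batch dimensions.
gram = torch.matmul(mat.transpose(2, 3), mat)  # shape: (batch, channel, Sy, Sy)

# Cross-check one matrix against a plain 2-D product.
reference = torch.mm(mat[1, 2].t(), mat[1, 2])
```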

I haven't tried it, but I guess you can do this with `torch.bmm`:

https://pytorch.org/docs/stable/torch.html?highlight=mm#torch.bmm

`m = mat.view(-1, Sx, Sy); torch.bmm(m.transpose(1, 2), m).view(batch, channels, Sy, Sy)`
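An expanded, runnable version of the `torch.bmm` idea might look like the sketch below. Two details matter: `torch.bmm` only takes 3-D inputs, so the batch and channel dims are flattened first, and each `A^T A` has shape (Sy, Sy), not (Sx, Sy). The sizes are placeholders.

```python
import torch

torch.manual_seed(0)
batch, channels, Sx, Sy = 2, 3, 5, 4
mat = torch.randn(batch, channels, Sx, Sy)

# Flatten (batch, channels) into one batch dim: (batch * channels, Sx, Sy).
m = mat.view(-1, Sx, Sy)

# Batched A^T A: (N, Sy, Sx) @ (N, Sx, Sy) -> (N, Sy, Sy).
gram = torch.bmm(m.transpose(1, 2), m)

# bmm returns a fresh contiguous tensor, so view can restore the 4-D layout.
gram = gram.view(batch, channels, Sy, Sy)
```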


This is the modified version. It is faster now, but still very time-consuming; I guess this is as good as it gets. Thank you very much for helping me solve this problem.

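```python
import torch
import torch.nn as nn


class NuclearLossFunc(nn.Module):
    def __init__(self):
        super(NuclearLossFunc, self).__init__()

    def forward(self, mat):
        loss = torch.zeros(1).cuda()
        total_batch, total_channel, Sx, Sy = mat.size()
        # Flatten batch and channel dims so torch.bmm sees a 3-D tensor.
        mat = mat.view(-1, Sx, Sy)
        mat_trans = torch.transpose(mat, 1, 2)
        # One A^T A per matrix: shape (total_batch * total_channel, Sy, Sy).
        m_total = torch.bmm(mat_trans, mat)
        # Remaining loop: accumulate the trace of each A^T A.
        for i in range(m_total.size(0)):
            loss += m_total[i, :, :].trace()
        loss /= m_total.size(0)
        return loss
```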

P.S. I didn't use the square root because it is too slow. You will need a smaller weight for this loss term.
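A caveat on that P.S.: without the square root, each term `trace(A^T A)` is the squared Frobenius norm of A, not its nuclear norm. The nuclear norm is `trace(sqrt(A^T A))` with a *matrix* square root, i.e. the sum of the singular values, so an elementwise `sqrt` would not give it either. Below is a minimal sketch of a true nuclear-norm loss that loops over matrices with `torch.svd`, assuming (as far as I know) that `torch.svd` only accepts a single 2-D matrix in 0.4.1, and that backprop through `torch.svd` works in your build; please check both before training with it. The name `nuclear_norm_loss` is hypothetical.

```python
import torch


def nuclear_norm_loss(mat):
    """Mean nuclear norm over every (batch, channel) matrix of a 4-D tensor.

    Hypothetical helper: loops over matrices one at a time because older
    PyTorch releases' torch.svd takes a single 2-D input.
    """
    _, _, Sx, Sy = mat.size()
    m = mat.contiguous().view(-1, Sx, Sy)
    total = 0
    for i in range(m.size(0)):
        # torch.svd returns (U, S, V); S holds the singular values of m[i].
        _, s, _ = torch.svd(m[i])
        total = total + s.sum()
    return total / m.size(0)


torch.manual_seed(0)
mat = torch.randn(2, 3, 5, 4)
loss = nuclear_norm_loss(mat)
```

This is likely to be slower than the sqrt-free version, since it runs one SVD per matrix inside a Python loop.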

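Hi,
According to the `torch.trace` docs, the trace of a 2-D matrix is the sum of its diagonal elements.

So why don't you sum all the matrices in the batch first and then take a single trace? That skips the for-loop, and unless I'm mistaken at this time of night, it should be equivalent, because the trace is linear: the sum of the traces equals the trace of the sum.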
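The linearity argument can be checked numerically. A small sketch, using a random tensor as a stand-in for the `(N, Sy, Sy)` output of `torch.bmm` (the shape is an arbitrary placeholder):

```python
import torch

torch.manual_seed(0)
# Stand-in for m_total, the (N, Sy, Sy) result of torch.bmm.
m_total = torch.randn(6, 4, 4)

# Loop version: one trace per matrix, accumulated in Python.
loop_sum = sum(m_total[i].trace() for i in range(m_total.size(0)))

# Vectorized version: sum the matrices first, then take a single trace.
vectorized = m_total.sum(0).trace()
```

Both give the same value up to floating-point rounding, but the second launches a fixed number of kernels instead of one per matrix.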

```python
import torch
import torch.nn as nn


class NuclearLossFunc(nn.Module):
    def __init__(self):
        super(NuclearLossFunc, self).__init__()

    def forward(self, mat):
        total_batch, total_channel, Sx, Sy = mat.size()
        mat = mat.view(-1, Sx, Sy)
        mat_trans = torch.transpose(mat, 1, 2)
        m_total = torch.bmm(mat_trans, mat)
        # The trace is linear: sum of traces == trace of the sum, so no loop.
        loss = m_total.sum(0).trace()
        # Average over the number of matrices, total_batch * total_channel
        # (the same as m_total.size(0)).
        loss /= (total_batch * total_channel)
        return loss
```
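Going one step further: `trace(A^T A)` is simply the sum of the squared entries of A, so the `bmm` and `trace` can be dropped entirely and the same sqrt-free loss computed with one elementwise operation. A sketch, assuming the goal is exactly the loss in the class above; `squared_frobenius_loss` is a hypothetical name, and it only uses `pow` and `sum`, which should also be available in 0.4.1:

```python
import torch


def squared_frobenius_loss(mat):
    # trace(A^T A) == sum of squared entries of A, so summing mat**2 over
    # everything and dividing by the number of matrices (batch * channel)
    # reproduces the bmm + trace loss without forming any A^T A.
    total_batch, total_channel = mat.size(0), mat.size(1)
    return mat.pow(2).sum() / (total_batch * total_channel)


torch.manual_seed(0)
mat = torch.randn(2, 3, 5, 4)
fast = squared_frobenius_loss(mat)

# Reference: the bmm + trace formulation from the thread.
m = mat.view(-1, 5, 4)
reference = torch.bmm(m.transpose(1, 2), m).sum(0).trace() / m.size(0)
```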

Regards
Juan

I didn't notice this! Thanks for pointing it out; the program is a lot faster now.