I have a slow computation issue

Hi all

I'm a newbie in machine learning and am working on a project.

But I'm stuck on the slow computation speed of my code.

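import torch.nn as nn

class geodesic(nn.Module):
    def __init__(self):
        super(geodesic, self).__init__()

    def forward(self, A, B):
        # Inputs are [batch, channel, vector]; average over the batch dimension.
        A_log = A.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        # Per-channel norm of the difference, then the mean over channels.
        out = A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True)
        out = out.mean()
        return out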

The above is my simple geodesic distance calculation module.
The inputs have the shape [batch, channel, vector].
Typically, the batch, channel, and vector sizes are 64, 128, and 64 respectively.
However, the calculation is very slow.
I think it is probably my fault, but I cannot find the main cause.
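
To make the shapes concrete, the forward pass can be traced step by step on random tensors of the sizes mentioned above. This is only an illustrative sketch: the random inputs and the name `per_channel` are stand-ins, not part of my real code.

```python
import torch

batch, channel, vector = 64, 128, 64
A = torch.randn(batch, channel, vector)
B = torch.randn(batch, channel, vector)

# transpose(0, 1) gives [channel, batch, vector]; mean over dim=1 averages across the batch.
A_log = A.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()  # [channel, vector]
B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()  # [channel, vector]

# Norm over the vector dimension: one value per channel.
per_channel = A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True)  # [channel, 1]

# Mean over channels: a single scalar loss term.
out = per_channel.mean()
print(A_log.shape, per_channel.shape, out.shape)
```

So the module only performs two reductions and one norm on tensors of about half a million elements each.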

Why is this module slow?
Please help this novice.

The following code is an example of using the module above.

Example:
some_optimizer.zero_grad()
self.geodesic=geodesic()

some_feature1=some_feature1.view(batch_size,channel_size,-1)
some_feature2=some_feature2.view(batch_size,channel_size,-1)
dist=self.geodesic(some_feature1,some_feature2)

total_loss=some_loss+dist
total_loss.backward()
some_optimizer.step()

Sorry, all.

I found that this module is not the problem.

Actually, I am doing SVD during training, along with a symmetric positive definite (SPD) matrix conversion.
I timed each operation. The SVD on [batch, channel, H, W] and the SPD matrix conversion are the main bottlenecks in my work.
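
For reference, a per-operation timing of this kind can be taken with a small helper. This is a minimal sketch: `timed`, `results`, and `sync` are hypothetical names, not part of my real code. If the tensors live on a GPU, CUDA kernels are launched asynchronously, so `torch.cuda.synchronize` would need to be passed as `sync` for the numbers to be meaningful.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label, results, sync=None):
    """Store the wall-clock time of the enclosed block in results[label]."""
    # Optionally wait for pending device work before starting the clock.
    if sync is not None:
        sync()
    start = time.perf_counter()
    try:
        yield
    finally:
        # Optionally wait for the timed work to finish before stopping the clock.
        if sync is not None:
            sync()
        results[label] = time.perf_counter() - start

# Demo on plain Python work; on a GPU one would pass sync=torch.cuda.synchronize.
results = {}
with timed("sum", results):
    total = sum(range(100_000))
print(results)
```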

So I want to talk about my batch SVD and SPD conversion.

The following are my modules.

import torch

def SPD_transform(U, D):
    [batch_size, channel_size, H, W] = U.size()
    left = U
    singular = D
    recon = []
    # Flatten batch and channel so each entry is a single matrix / vector.
    left = left.view(-1, left.size(2), left.size(3))
    singular = singular.view(-1, singular.size(2))

    # Rebuild U * diag(d) * U^T one matrix at a time.
    for idx, vector in enumerate(singular):
        recon += [torch.matmul(torch.matmul(left[idx, :, :], vector.diag()), left[idx, :, :].t())]

    recon = torch.stack(recon)
    recon = recon.view(batch_size, channel_size, H, W)

    return recon

def svd(x):
    if x.dim() == 2:
        result = torch.svd(x)
    else:
        batches = x.shape[:-2]
        other = x.shape[-2:]
        flat = x.view((-1,) + other)
        slices = flat.unbind(0)
        U, D, V = [], [], []
        # I wish I had a parallel_for
        for i in range(flat.shape[0]):
            u, d, v = torch.svd(slices[i])
            U += [u]
            D += [d]
            V += [v]
        U = torch.stack(U).view(batches + U[0].shape)
        D = torch.stack(D).view(batches + D[0].shape)
        V = torch.stack(V).view(batches + V[0].shape)
        result = U, D, V
    return result

Typically, the input shape of svd is [64, 128, 8, 8].
The input shapes of SPD_transform are as follows:
left singular vectors: [64, 128, 8, 8]
singular values: [64, 128, 8]

Can I reduce the computation time of my modules?
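
One direction that might remove the Python loops is sketched below. It assumes a PyTorch version that provides `torch.linalg.svd`, which accepts batched input, and real-valued square matrices as in my case. `batched_svd` and `spd_transform_batched` are hypothetical names. Whether this is actually faster for 8x8 matrices on a given device would still need to be measured.

```python
import torch

def batched_svd(x):
    """SVD over the last two dims of x; all leading dims are treated as batch dims."""
    # torch.linalg.svd returns Vh (V transposed), while torch.svd returns V,
    # so transpose back to keep the (U, D, V) convention of the loop version.
    U, D, Vh = torch.linalg.svd(x, full_matrices=False)
    return U, D, Vh.transpose(-2, -1)

def spd_transform_batched(U, D):
    """Compute U @ diag(D) @ U^T for every matrix in the batch without a Python loop."""
    # Broadcasting D over rows scales column k of U by D[..., k], same as U @ diag(D).
    return (U * D.unsqueeze(-2)) @ U.transpose(-2, -1)

# Small stand-in shapes for the demo; my real input is [64, 128, 8, 8].
x = torch.randn(4, 6, 8, 8, dtype=torch.float64)
U, D, V = batched_svd(x)
spd = spd_transform_batched(U, D)
print(U.shape, D.shape, V.shape, spd.shape)
```

The singular vectors of an SVD are not unique (column signs can flip), so `U` may differ from what the loop version returns. The product U diag(D) U^T does not depend on those sign choices.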