Hi all

I'm a newbie to machine learning and am working on a project.

But I'm stuck on the slow computation speed of my code.

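```python
import torch.nn as nn


class geodesic(nn.Module):

    def __init__(self):
        super(geodesic, self).__init__()

    def forward(self, A, B):
        # [batch, channel, vector] -> [channel, batch, vector], then average over the batch -> [channel, vector]
        A_log = A.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        # L2 norm of the difference along the vector dimension -> [channel, 1]
        out = A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True)
        # average over channels -> scalar
        out = out.mean()
        return out
```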

The module above is my simple geodesic distance calculation.

The inputs have the shape [batch, channel, vector].

Typically, the batch, channel and vector sizes are 64, 128 and 64, respectively.
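To make the shapes concrete, here is a small walk-through on random tensors of the stated sizes (the tensors are placeholders, not real features). It repeats the steps of `forward` and prints each intermediate shape. It also checks that, for these sizes, `transpose(0, 1).mean(dim=1, keepdim=True).squeeze()` gives the same result as `A.mean(dim=0)`, i.e. the module averages over the batch and then takes a per-channel L2 distance.

```python
import torch

# Random stand-ins with the shapes from the post: [batch, channel, vector]
batch, channel, vector = 64, 128, 64
A = torch.randn(batch, channel, vector)
B = torch.randn(batch, channel, vector)

# Same steps as geodesic.forward, printing each intermediate shape
t = A.transpose(0, 1)                 # [channel, batch, vector]
print(t.shape)                        # torch.Size([128, 64, 64])
m = t.mean(dim=1, keepdim=True)       # [channel, 1, vector]
print(m.shape)                        # torch.Size([128, 1, 64])
A_log = m.squeeze()                   # [channel, vector]
print(A_log.shape)                    # torch.Size([128, 64])

B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
per_channel = A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True)
print(per_channel.shape)              # torch.Size([128, 1])
dist = per_channel.mean()             # scalar
print(dist.shape)                     # torch.Size([])

# For these sizes the transpose/keepdim/squeeze sequence equals averaging over dim 0
print(torch.allclose(A_log, A.mean(dim=0), atol=1e-6))  # True
```

Note that `squeeze()` would also drop the channel or vector dimension if either had size 1, so `A.mean(dim=0)` is only an exact substitute when neither is 1.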

The problem is that the calculation is very, very slow.

I think it's probably my fault, but I can't find the main reason for it.

Why is this module slow?

Please help this novice.
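For inputs of this size the module does only a couple of reductions and elementwise operations, so it may be worth timing it in isolation before looking for a bug inside it. On a GPU, CUDA kernels are launched asynchronously, so a timer placed around one line can end up charging it for work queued earlier in the step. A minimal timing sketch, assuming random [64, 128, 64] inputs; the helper name `time_geodesic` and the warm-up and iteration counts are arbitrary choices for illustration, and the sketch does not by itself identify the cause of the slowdown.

```python
import time

import torch
import torch.nn as nn


class geodesic(nn.Module):
    # Same computation as the module in the post
    def forward(self, A, B):
        A_log = A.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        out = A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True)
        return out.mean()


def time_geodesic(device, iters=100):
    """Average seconds per forward+backward of geodesic on [64, 128, 64] inputs (hypothetical helper)."""
    A = torch.randn(64, 128, 64, device=device, requires_grad=True)
    B = torch.randn(64, 128, 64, device=device, requires_grad=True)
    module = geodesic().to(device)

    def sync():
        # CUDA work is asynchronous: wait for queued kernels before reading the clock
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # Warm-up so one-time setup cost is not included in the measurement
    for _ in range(10):
        module(A, B).backward()
    sync()

    start = time.perf_counter()
    for _ in range(iters):
        module(A, B).backward()
    sync()
    return (time.perf_counter() - start) / iters


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{device}: {time_geodesic(device) * 1e3:.3f} ms per forward+backward")
```

If this number is small compared with a whole training step, the time is most likely being spent elsewhere in the step.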

The following code is an example of how I use the module above:

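```python
some_optimizer.zero_grad()

# geodesic has no parameters; creating it once (e.g. in __init__) would avoid rebuilding it every step
self.geodesic = geodesic()

# reshape both features to [batch, channel, vector]
some_feature1 = some_feature1.view(batch_size, channel_size, -1)
some_feature2 = some_feature2.view(batch_size, channel_size, -1)

dist = self.geodesic(some_feature1, some_feature2)
total_loss = some_loss + dist

total_loss.backward()
some_optimizer.step()
```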
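For anyone who wants to reproduce the step end to end, here is a self-contained sketch of the same training step. Everything except `geodesic` is hypothetical: `net1` and `net2` stand in for whatever produces `some_feature1` and `some_feature2`, and the input `x`, the SGD optimizer and the placeholder `some_loss` are made up for illustration. The one structural difference is that the parameter-free `geodesic` module is created once outside the loop instead of on every step.

```python
import torch
import torch.nn as nn


class geodesic(nn.Module):
    # Same computation as the module in the post
    def forward(self, A, B):
        A_log = A.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        out = A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True)
        return out.mean()


batch_size, channel_size, vector_size = 64, 128, 64

# Hypothetical networks standing in for whatever produces some_feature1/2
net1 = nn.Linear(32, channel_size * vector_size)
net2 = nn.Linear(32, channel_size * vector_size)
some_optimizer = torch.optim.SGD(list(net1.parameters()) + list(net2.parameters()), lr=0.01)
dist_fn = geodesic()  # created once and reused every step

x = torch.randn(batch_size, 32)  # hypothetical input batch
for step in range(3):
    some_optimizer.zero_grad()
    some_feature1 = net1(x).view(batch_size, channel_size, -1)  # [64, 128, 64]
    some_feature2 = net2(x).view(batch_size, channel_size, -1)  # [64, 128, 64]
    dist = dist_fn(some_feature1, some_feature2)
    some_loss = some_feature1.pow(2).mean()  # hypothetical placeholder for the other loss term
    total_loss = some_loss + dist
    total_loss.backward()
    some_optimizer.step()
    print(step, total_loss.item())
```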