Hi all,
I'm new to machine learning and am working on a project,
but I'm stuck because my code runs very slowly.
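import torch
import torch.nn as nn

class geodesic(nn.Module):
    def __init__(self):
        super(geodesic, self).__init__()

    def forward(self, A, B):
        # [batch, channel, vector] -> average over batch -> [channel, vector]
        A_log = A.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        # L2 distance per channel -> [channel, 1], then average over channels
        out = A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True)
        out = out.mean()
        return out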
The above is my simple geodesic distance module.
The inputs have the shape [batch, channel, vector].
Typically, the batch, channel, and vector sizes are 64, 128, and 64, respectively.
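With those sizes, the intermediate shapes in `forward` can be traced step by step. The sketch below uses random tensors as stand-ins for the real features (an assumption) and also checks a possibly simpler formulation, `geodesic_simple` (a hypothetical name, not from the original code): averaging the transposed tensor over dim 1 is the same as averaging the original over dim 0, so the transpose, `keepdim=True` and `squeeze()` can be dropped.

```python
import torch

batch, channel, vector = 64, 128, 64  # sizes from the post
torch.manual_seed(0)
A = torch.randn(batch, channel, vector)  # random stand-ins for the real features
B = torch.randn(batch, channel, vector)

# Shape trace of the original forward()
t = A.transpose(0, 1)               # [channel, batch, vector]
print(tuple(t.shape))               # (128, 64, 64)
m = t.mean(dim=1, keepdim=True)     # average over batch -> [channel, 1, vector]
print(tuple(m.shape))               # (128, 1, 64)
A_log = m.squeeze()                 # drop size-1 dims -> [channel, vector]
print(tuple(A_log.shape))           # (128, 64)
B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
d = A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True)  # one L2 distance per channel
print(tuple(d.shape))               # (128, 1)
original = d.mean()                 # scalar


def geodesic_simple(A, B):
    """Hypothetical, more direct version of forward() for [batch, channel, vector] inputs."""
    A_mean = A.mean(dim=0)  # average over batch -> [channel, vector]
    B_mean = B.mean(dim=0)
    # L2 norm over the vector dimension -> [channel]
    per_channel = torch.linalg.vector_norm(A_mean - B_mean, ord=2, dim=1)
    return per_channel.mean()  # average over channels -> scalar


# The two results should agree up to floating-point rounding
print(torch.allclose(original, geodesic_simple(A, B), atol=1e-6))
```

Both versions touch only about half a million elements per input (64 * 128 * 64 = 524,288), so neither is expected to be slow on its own; the rewrite is mainly for readability. One difference worth knowing: `squeeze()` in the original would also drop the channel dimension if `channel` were 1, which the `mean(dim=0)` version avoids.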
However, the calculation is very, very slow.
I suspect it's my own mistake, but I can't find the main cause.
Why is this module so slow?
Please help this novice.
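One way to check whether this module is really the bottleneck is to time it on its own, forward and backward, with random inputs of the stated size. This is a sketch under a few assumptions: random tensors stand in for the real features, and on a GPU `torch.cuda.synchronize()` is called before reading the clock. CUDA work is asynchronous, so without synchronizing, the time of earlier GPU work tends to show up in whatever line happens to wait for the GPU next (often `backward()` or `.item()`), which can make an innocent line look slow.

```python
import time

import torch
import torch.nn as nn


class geodesic(nn.Module):
    # Same computation as the module in the post
    def forward(self, A, B):
        A_log = A.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        out = A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True)
        return out.mean()


def sync(device):
    # Wait for queued CUDA kernels so the timer measures the actual work
    if device.type == "cuda":
        torch.cuda.synchronize()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = geodesic().to(device)
A = torch.randn(64, 128, 64, device=device, requires_grad=True)
B = torch.randn(64, 128, 64, device=device, requires_grad=True)

# Warm-up: the first calls can include one-off setup cost
for _ in range(3):
    module(A, B).backward()
sync(device)

n_iters = 100
start = time.perf_counter()
for _ in range(n_iters):
    dist = module(A, B)
    dist.backward()
sync(device)
elapsed = (time.perf_counter() - start) / n_iters
print(f"{device.type}: {elapsed * 1e3:.3f} ms per forward+backward")
```

If the per-iteration time measured this way is small compared with a full training step, the slowdown probably comes from elsewhere (for example the network that produces the two feature tensors); `torch.profiler` can help narrow it down.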
The following code shows how I use the module:
some_optimizer.zero_grad()
self.geodesic = geodesic()
some_feature1 = some_feature1.view(batch_size, channel_size, -1)
some_feature2 = some_feature2.view(batch_size, channel_size, -1)
dist = self.geodesic(some_feature1, some_feature2)
total_loss = some_loss + dist
total_loss.backward()
some_optimizer.step()
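For comparison, here is a minimal, self-contained sketch of the same training step. It assumes a hypothetical `Model` whose two convolution branches stand in for whatever produces `some_feature1` and `some_feature2`, and a placeholder value for `some_loss`. The only deliberate differences are that the `geodesic` module is built once in `__init__` instead of on every step, and `reshape` is used instead of `view` so non-contiguous features do not raise an error. Neither change is expected to fix a large slowdown by itself, since building a parameter-free module is cheap.

```python
import torch
import torch.nn as nn


class geodesic(nn.Module):
    # Same computation as the module in the post
    def forward(self, A, B):
        A_log = A.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        B_log = B.transpose(0, 1).mean(dim=1, keepdim=True).squeeze()
        return A_log.sub(B_log).norm(p='fro', dim=1, keepdim=True).mean()


class Model(nn.Module):
    """Hypothetical model: two conv branches stand in for the real feature extractors."""

    def __init__(self, channel_size=128):
        super().__init__()
        self.branch1 = nn.Conv2d(3, channel_size, kernel_size=3, padding=1)
        self.branch2 = nn.Conv2d(3, channel_size, kernel_size=3, padding=1)
        self.geodesic = geodesic()  # created once and reused on every step

    def forward(self, x):
        f1 = self.branch1(x)  # [batch, channel, H, W]
        f2 = self.branch2(x)
        batch_size, channel_size = f1.shape[:2]
        # Flatten the spatial dims: [batch, channel, H*W]
        f1 = f1.reshape(batch_size, channel_size, -1)
        f2 = f2.reshape(batch_size, channel_size, -1)
        return self.geodesic(f1, f2)


model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(64, 3, 8, 8)  # 8 * 8 = 64 -> vector size 64, as in the post

optimizer.zero_grad()
dist = model(x)
some_loss = torch.tensor(0.0)  # placeholder for the real task loss
total_loss = some_loss + dist
total_loss.backward()
optimizer.step()
print(total_loss.item())
```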