Hi, I implemented the classical multiplicative-update NMF algorithm in PyTorch, but it slows down over iterations on CPU. The slowdown is significant in single precision and also occurs in double precision. PyTorch is compiled from source, and the behavior reproduces on two different systems (Intel® Xeon® CPU E5-2680 v2 and Intel® Xeon® CPU E5-2650 v4).
The algorithm is basically a sequence of matrix multiplications, and as expected the most time-consuming part is the multiplication that has the dense data matrix as an operand (this is also confirmed by profiling). The problem is that this particular multiplication step keeps getting slower over iterations. The same thing happens when I pass the out= argument to torch.mm. Is this an issue in PyTorch, or am I missing something?
- Memory usage does not increase over time.
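For reference, the update_V and update_W methods in the script below implement the classical multiplicative updates for minimizing the squared Frobenius error $\|X - VW\|_F^2$, with $X \in \mathbb{R}^{p \times q}$, $V \in \mathbb{R}^{p \times r}$, $W \in \mathbb{R}^{r \times q}$:

$$V \leftarrow V \odot \frac{X W^\top}{V (W W^\top)}, \qquad W \leftarrow W \odot \frac{V^\top X}{(V^\top V)\, W},$$

where $\odot$ and the fraction bar denote elementwise multiplication and division. Here is the full script: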
import torch
import time


class NMF():
    def __init__(self, data, r, TType=torch.DoubleTensor):
        """
        data: a Tensor of shape (p, q).
        r: a small integer, the target rank.
        """
        self.TType = TType
        self.p, self.q = data.shape
        self.r = r
        p, q = self.p, self.q
        assert self.r <= self.p and self.r <= self.q
        # initialize V (p x r) and W (r x q); *_prev hold the previous iterates
        # and are filled by copy_() before they are read
        self.V_prev = TType(p, r)
        self.W_prev = TType(q, r).t()
        self.V = TType(p, r).uniform_()
        self.W = TType(q, r).uniform_().t()
        self.data = data

    def update_V(self):
        # V <- V * (X W^T) / (V W W^T), elementwise
        XWt = torch.mm(self.data, self.W.t())  # this line gets slower over iterations
        WWt = torch.mm(self.W, self.W.t())
        VWWt = torch.mm(self.V, WWt)
        self.V = self.V.mul_(XWt).div_(VWWt)

    def update_W(self):
        # W <- W * (V^T X) / (V^T V W), elementwise
        VtX = torch.mm(self.V.t(), self.data)  # this line gets slower over iterations
        VtV = torch.mm(self.V.t(), self.V)
        VtVW = torch.mm(VtV, self.W)
        self.W = self.W.mul_(VtX).div_(VtVW)

    def get_objective(self):
        # squared Frobenius norm of the residual X - V W
        outer = torch.mm(self.V, self.W)
        val = torch.sum((self.data - outer)**2)
        return val

    def check_convergence(self, tol, verbose=True, check_obj=False):
        obj = None
        diff_norm_1 = torch.max(torch.abs(self.V_prev - self.V))
        diff_norm_2 = torch.max(torch.abs(self.W_prev - self.W))
        if check_obj:
            obj = self.get_objective()
        converged = diff_norm_1 < tol and diff_norm_2 < tol
        return (diff_norm_1, diff_norm_2, obj), converged

    def run(self, maxiter=100, tol=1e-5, check_interval=1, verbose=True, check_obj=False):
        if verbose:
            print("Starting...")
            print("p={}, q={}, r={}".format(self.p, self.q, self.r))
            if not check_obj:
                print("%6s\t%10s\t%10s\t%10s" % ("iter", "V_maxdiff", "W_maxdiff", "time"))
            else:
                print("%6s\t%10s\t%10s\t%10s\t%10s" % ("iter", "V_maxdiff", "W_maxdiff", "obj", "time"))
            print('-' * 80)
        t0 = time.time()
        t_start = t0
        for i in range(maxiter):
            self.V_prev.copy_(self.V)
            self.W_prev.copy_(self.W)
            self.update_V()
            self.update_W()
            if (i + 1) % check_interval == 0:
                t1 = time.time()
                (v_maxdiff, w_maxdiff, obj), converged = self.check_convergence(tol, verbose, check_obj)
                if verbose:
                    if not check_obj:
                        print("%6d\t%10.4e\t%10.4e\t%10.5f" % (i + 1, v_maxdiff, w_maxdiff, t1 - t0))
                    else:
                        print("%6d\t%10.4e\t%10.4e\t%10.4e\t%10.5f" % (i + 1, v_maxdiff, w_maxdiff,
                                                                       obj, t1 - t0))
                if converged: break
                t0 = t1
        if verbose:
            print('-' * 80)
            print("Completed. total time: {}".format(time.time() - t_start))


if __name__ == '__main__':
    torch.manual_seed(95376)
    TType = torch.FloatTensor
    m = TType(1000, 10000).uniform_()
    nmf_driver = NMF(m, 60, TType)
    nmf_driver.run(3000, check_interval=100, check_obj=True)
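(The profiling mentioned above can be done roughly like this; a sketch over a shortened run, the iteration count is arbitrary:)

with torch.autograd.profiler.profile() as prof:
    nmf_driver.run(200, check_interval=100, check_obj=False)
# aggregate per-op CPU time; the mm calls involving the data matrix dominate
print(prof.key_averages())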
A sample output of the full script is:
Starting...
p=1000, q=10000, r=60
iter V_maxdiff W_maxdiff obj time
--------------------------------------------------------------------------------
100 6.5212e-04 4.5762e-02 7.9534e+05 3.34130
200 3.4037e-04 3.3905e-02 7.7984e+05 3.11377
300 2.4420e-04 2.0304e-02 7.7440e+05 3.21589
400 2.3506e-04 9.8728e-03 7.7161e+05 3.18056
500 1.9985e-04 7.4172e-03 7.6984e+05 3.19203
600 1.5385e-04 7.9482e-03 7.6861e+05 3.05583
700 1.1968e-04 4.8654e-03 7.6769e+05 3.23273
800 1.0916e-04 4.7331e-03 7.6698e+05 2.98794
900 1.2066e-04 5.0110e-03 7.6642e+05 2.98134
1000 1.1725e-04 5.8762e-03 7.6596e+05 2.99186
1100 1.0439e-04 4.3106e-03 7.6557e+05 2.98801
1200 8.8587e-05 4.9135e-03 7.6525e+05 3.00293
1300 9.0437e-05 4.1634e-03 7.6497e+05 3.13171
1400 7.4662e-05 4.8460e-03 7.6473e+05 3.12103
1500 6.4019e-05 4.2907e-03 7.6452e+05 3.25630
1600 6.1292e-05 3.7585e-03 7.6434e+05 3.44411
1700 6.1473e-05 4.0500e-03 7.6418e+05 3.75131
1800 5.6544e-05 4.6178e-03 7.6403e+05 4.17858
1900 5.5443e-05 4.5151e-03 7.6390e+05 4.64420
2000 5.0263e-05 5.2219e-03 7.6378e+05 5.22469
2100 5.1829e-05 3.8038e-03 7.6367e+05 5.99058
2200 5.6377e-05 4.2342e-03 7.6357e+05 6.83499
2300 5.2219e-05 3.6758e-03 7.6348e+05 7.53499
2400 5.4819e-05 4.0480e-03 7.6340e+05 8.42264
2500 6.0137e-05 3.1748e-03 7.6332e+05 9.34333
2600 5.6073e-05 3.0228e-03 7.6325e+05 10.41376
2700 4.8665e-05 3.1417e-03 7.6318e+05 11.47899
2800 4.5583e-05 3.8174e-03 7.6312e+05 12.37672
2900 4.9257e-05 3.0648e-03 7.6306e+05 13.32979
3000 4.8155e-05 3.3406e-03 7.6300e+05 14.12009
--------------------------------------------------------------------------------
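In case it helps with diagnosis: below is a minimal sketch of how the suspect multiplication could be timed in isolation, assuming copies of W are saved early and late in the run. The helper time_mm and the snapshot variables are illustrative only, not part of the script above.

import time
import torch

def time_mm(data, W, repeats=10):
    # time only torch.mm(data, W.t()), averaged over a few repeats
    t0 = time.time()
    for _ in range(repeats):
        torch.mm(data, W.t())
    return (time.time() - t0) / repeats

# e.g. compare a copy of W saved around iteration 100 with one saved
# near iteration 3000 (both are hypothetical snapshot variables):
# print(time_mm(m, W_iter100), time_mm(m, W_iter3000))

This separates the cost of that single mm call from everything else the class does per iteration, so the slowdown can be attributed either to the values stored in W or to some accumulated state.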