Matrix multiplication slows down over iterations on CPU (single and double precision)

Hi, I implemented a classical multiplicative-update NMF algorithm in PyTorch, but it gets slower as the iterations progress on the CPU. The slowdown is significant in single precision, and it also happens in double precision. PyTorch is compiled from source, and I tested it on two different systems (Intel® Xeon® CPU E5-2680 v2 and Intel® Xeon® CPU E5-2650 v4).

It is basically a sequence of matrix multiplications, and of course the most time-consuming part is the multiplication in which the dense data matrix is an operand (profiling confirms this). The problem is that this particular multiplication step keeps getting slower. The same thing happens when I use the out= argument of torch.mm. Is this a PyTorch issue, or am I missing something?

  • Memory usage did not increase over time.
import torch
import time
class NMF():
    """Non-negative matrix factorization via multiplicative updates.

    Approximates ``data`` (shape (p, q)) by ``V @ W`` with V of shape (p, r)
    and W of shape (r, q), using the multiplicative update rules
    V <- V * (X W^T) / (V W W^T) and W <- W * (V^T X) / (V^T V W).
    ``get_objective`` reports the squared reconstruction error.

    NOTE: the multiplicative updates can drive entries of V and W toward
    zero, into the denormal (subnormal) float range. Arithmetic on denormals
    is much slower on many CPUs, so the large products against ``data``
    become progressively slower. Pass ``flush_denormal=True`` to ``run``
    (or call ``torch.set_flush_denormal(True)`` yourself) to avoid this.
    """

    def __init__(self, data, r, TType=torch.DoubleTensor):
        """
        data: a 2-D Tensor of shape (p, q). Presumably non-negative, so that
            the multiplicative updates keep V and W non-negative.
        r: the factorization rank, a small integer with r <= min(p, q).
        TType: legacy tensor-type constructor (e.g. torch.FloatTensor) used
            to allocate V and W. It should produce the same dtype as
            ``data``, because torch.mm requires matching operand types.
        """
        self.TType = TType

        self.p, self.q = data.shape
        self.r = r
        p, q = self.p, self.q
        assert self.r <= self.p and self.r <= self.q

        # Snapshots of the previous iterate, used by check_convergence.
        # They are zero-filled rather than left uninitialised, so calling
        # check_convergence before run() does not read garbage memory.
        self.V_prev = TType(p, r).zero_()
        self.W_prev = TType(q, r).zero_().t()

        # Random non-negative initial factors. W is the transpose of a
        # (q, r) tensor, i.e. a non-contiguous (r, q) view.
        self.V = TType(p, r).uniform_()
        self.W = TType(q, r).uniform_().t()

        self.data = data

    def update_V(self):
        """Update V in place: V <- V * (X W^T) / (V W W^T)."""
        # The product with the large dense data matrix dominates the cost.
        # It is the step that slows down once W holds denormal values.
        XWt = torch.mm(self.data, self.W.t())
        WWt = torch.mm(self.W, self.W.t())
        VWWt = torch.mm(self.V, WWt)
        self.V = self.V.mul_(XWt).div_(VWWt)

    def update_W(self):
        """Update W in place: W <- W * (V^T X) / (V^T V W)."""
        # As in update_V, the product with the data matrix dominates and
        # slows down once V holds denormal values.
        VtX = torch.mm(self.V.t(), self.data)
        VtV = torch.mm(self.V.t(), self.V)
        VtVW = torch.mm(VtV, self.W)
        self.W = self.W.mul_(VtX).div_(VtVW)

    def get_objective(self):
        """Return the sum of squared entries of ``data - V @ W`` (a tensor)."""
        outer = torch.mm(self.V, self.W)
        val = torch.sum(((self.data - outer)**2))
        return val

    def check_convergence(self, tol, verbose=True, check_obj=False):
        """Compare the current iterate with the previous one.

        tol: threshold on the maximum absolute element-wise change.
        verbose: accepted for interface compatibility; currently unused.
        check_obj: if True, also compute the objective.

        Returns ``((v_maxdiff, w_maxdiff, obj), converged)``. ``obj`` is None
        unless ``check_obj`` is set. ``converged`` is true when both maximum
        changes are below ``tol``.
        """
        obj = None
        diff_norm_1 = torch.max(torch.abs(self.V_prev - self.V))
        diff_norm_2 = torch.max(torch.abs(self.W_prev - self.W))
        if check_obj:
            obj = self.get_objective()
        converged = diff_norm_1 < tol and diff_norm_2 < tol
        return (diff_norm_1, diff_norm_2, obj), converged

    def run(self, maxiter=100, tol=1e-5, check_interval=1, verbose=True, check_obj=False,
            flush_denormal=False):
        """Iterate the updates until convergence or ``maxiter`` sweeps.

        maxiter: maximum number of (V, W) update sweeps.
        tol: convergence threshold passed to check_convergence.
        check_interval: test convergence (and print progress) every this
            many sweeps.
        verbose: print a progress table with per-interval wall time.
        check_obj: also compute and print the objective at each check.
        flush_denormal: if True, call ``torch.set_flush_denormal(True)``
            before iterating. Denormal values are then flushed to zero
            instead of slowing down every matmul. This is a process-wide
            setting that stays in effect after ``run`` returns. It only
            has an effect on CPUs where PyTorch supports it.
        """
        if flush_denormal:
            # Returns False (and changes nothing) where unsupported.
            torch.set_flush_denormal(True)

        if verbose:
            print("Starting...")
            print("p={}, q={}, r={}".format(self.p, self.q, self.r))
            if not check_obj:
                print("%6s\t%10s\t%10s\t%10s" % ("iter", "V_maxdiff", "W_maxdiff", "time"))
            else:
                print("%6s\t%10s\t%10s\t%10s\t%10s" % ("iter", "V_maxdiff", "W_maxdiff", "obj", "time" ))
            print('-'*80)

        t0 = time.time()
        t_start = t0

        for i in range(maxiter):
            # Snapshot the current iterate in place (no reallocation).
            self.V_prev.copy_(self.V)
            self.W_prev.copy_(self.W)

            self.update_V()
            self.update_W()
            if (i+1) % check_interval == 0:
                t1 = time.time()
                (v_maxdiff, w_maxdiff, obj), converged = self.check_convergence(tol, verbose, check_obj)
                if verbose:
                    if not check_obj:
                        print("%6d\t%10.4e\t%10.4e\t%10.5f" % (i+1, v_maxdiff, w_maxdiff, t1-t0))
                    else:
                        print("%6d\t%10.4e\t%10.4e\t%10.4e\t%10.5f" % (i+1, v_maxdiff, w_maxdiff,
                                                                            obj, t1-t0))
                if converged:
                    break
                # Time is reported per check interval, not cumulatively.
                t0 = t1

        if verbose:
            print('-'*80)
            print("Completed. total time: {}".format(time.time()-t_start))
if __name__ == '__main__':
    # The multiplicative updates push many entries of V and W into the
    # denormal float range. CPU arithmetic on denormals is much slower, so
    # the per-interval iteration time keeps growing. Flushing denormals to
    # zero keeps it flat. The call returns False and changes nothing on
    # CPUs where it is unsupported.
    torch.set_flush_denormal(True)

    torch.manual_seed(95376)
    TType = torch.FloatTensor
    m = TType(1000, 10000).uniform_()
    nmf_driver = NMF(m, 60, TType)
    nmf_driver.run(3000, check_interval=100, check_obj=True)

Sample output (the last column is the wall time per 100 iterations):

Starting...
p=1000, q=10000, r=60
  iter	 V_maxdiff	 W_maxdiff	       obj	      time
--------------------------------------------------------------------------------
   100	6.5212e-04	4.5762e-02	7.9534e+05	   3.34130
   200	3.4037e-04	3.3905e-02	7.7984e+05	   3.11377
   300	2.4420e-04	2.0304e-02	7.7440e+05	   3.21589
   400	2.3506e-04	9.8728e-03	7.7161e+05	   3.18056
   500	1.9985e-04	7.4172e-03	7.6984e+05	   3.19203
   600	1.5385e-04	7.9482e-03	7.6861e+05	   3.05583
   700	1.1968e-04	4.8654e-03	7.6769e+05	   3.23273
   800	1.0916e-04	4.7331e-03	7.6698e+05	   2.98794
   900	1.2066e-04	5.0110e-03	7.6642e+05	   2.98134
  1000	1.1725e-04	5.8762e-03	7.6596e+05	   2.99186
  1100	1.0439e-04	4.3106e-03	7.6557e+05	   2.98801
  1200	8.8587e-05	4.9135e-03	7.6525e+05	   3.00293
  1300	9.0437e-05	4.1634e-03	7.6497e+05	   3.13171
  1400	7.4662e-05	4.8460e-03	7.6473e+05	   3.12103
  1500	6.4019e-05	4.2907e-03	7.6452e+05	   3.25630
  1600	6.1292e-05	3.7585e-03	7.6434e+05	   3.44411
  1700	6.1473e-05	4.0500e-03	7.6418e+05	   3.75131
  1800	5.6544e-05	4.6178e-03	7.6403e+05	   4.17858
  1900	5.5443e-05	4.5151e-03	7.6390e+05	   4.64420
  2000	5.0263e-05	5.2219e-03	7.6378e+05	   5.22469
  2100	5.1829e-05	3.8038e-03	7.6367e+05	   5.99058
  2200	5.6377e-05	4.2342e-03	7.6357e+05	   6.83499
  2300	5.2219e-05	3.6758e-03	7.6348e+05	   7.53499
  2400	5.4819e-05	4.0480e-03	7.6340e+05	   8.42264
  2500	6.0137e-05	3.1748e-03	7.6332e+05	   9.34333
  2600	5.6073e-05	3.0228e-03	7.6325e+05	  10.41376
  2700	4.8665e-05	3.1417e-03	7.6318e+05	  11.47899
  2800	4.5583e-05	3.8174e-03	7.6312e+05	  12.37672
  2900	4.9257e-05	3.0648e-03	7.6306e+05	  13.32979
  3000	4.8155e-05	3.3406e-03	7.6300e+05	  14.12009
--------------------------------------------------------------------------------

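This was caused by the large number of denormal (subnormal) floating-point numbers created during the iterations. Arithmetic on denormals is far slower than on normal floats on most CPUs, so the matrix multiplications slow down as more entries become denormal. Calling `torch.set_flush_denormal(True)` before the loop flushes such values to zero and avoids the slowdown.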

This issue is discussed here.