I would be very grateful if you could help me with the following issue.
In my network I need orthogonal matrices that remain orthogonal throughout training.
For this purpose I wrote a class `Orth_mat(nn.Module)` that builds an orthogonal matrix from the given parameters. These parameters are the input of `Orth_mat`’s `forward` method and, at the same time, the optimization parameters of my network. Therefore I had to use the Variable API for the operations inside the `Orth_mat` class:
```python
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable


class Orth_mat(nn.Module):
    def __init__(self, shape):
        '''
        shape: shape (m, n) of the orthogonal matrix to build, m >= n
        '''
        super(Orth_mat, self).__init__()
        self.shape = shape

    def forward(self, input):
        m, n = self.shape
        assert len(input) == m*n - n*(n+1)//2
        assert m >= n

        # constructing the m x n "unity" matrix (identity padded with zero rows)
        unity_mat = torch.FloatTensor(np.identity(n))
        for i in range(m-n):
            col = torch.zeros(n)
            unity_mat = torch.cat([unity_mat, col.unsqueeze(0)], dim=0)

        orth_mat = Variable(torch.FloatTensor(np.identity(m)))
        loop_len = n if m > n else n-1
        for i in range(loop_len):
            # dividing the parameters into groups, one per Householder reflection
            init = 0
            for l in range(i):
                init += m-l-1
            params_local = input[init:init+m-i-1]

            # constructing the unit vector "uvec" that parametrizes each Householder reflection
            uvec = Variable(torch.zeros(m))
            for j in range(m-i):
                if j != m-i-1:
                    uvec_j = torch.cos(params_local[j])
                    for k in range(j):
                        uvec_j = uvec_j*torch.sin(params_local[k])
                else:
                    uvec_j = torch.sin(params_local[j-1])
                    for k in range(j-1):
                        uvec_j = uvec_j*torch.sin(params_local[k])
                uvec.index_copy_(0, Variable(torch.LongTensor([j])), uvec_j)

            # constructing a Householder reflection
            hh_refl = Variable(torch.FloatTensor(np.identity(m))) - 2*torch.mm(uvec.unsqueeze(1), uvec.unsqueeze(0))
            # accumulating the orthogonal matrix
            orth_mat = torch.mm(hh_refl, orth_mat)

        # making the matrix rectangular if needed
        orth_mat = torch.mm(orth_mat, Variable(unity_mat))
        return orth_mat
```
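Written out as formulas, this is how I read the construction in `Orth_mat` (the symbols $L$ for `loop_len`, $\theta^{(i)}$ for `params_local` of reflection $i$ and $u^{(i)}$ for `uvec` are my own notation). Each reflection uses a unit vector in hyperspherical coordinates whose last $i$ entries are zero, the reflections are multiplied from the left, and the first $n$ columns are kept:

$$
u^{(i)}_j = \cos\theta^{(i)}_j \prod_{k<j} \sin\theta^{(i)}_k \quad (0 \le j < m-i-1), \qquad
u^{(i)}_{m-i-1} = \prod_{k=0}^{m-i-2} \sin\theta^{(i)}_k, \qquad
u^{(i)}_j = 0 \quad (j \ge m-i)
$$

$$
H_i = I_m - 2\,u^{(i)} u^{(i)\top}, \qquad
Q = H_{L-1} \cdots H_1 H_0 \, I_{m \times n}, \qquad
L = \begin{cases} n, & m > n \\ n-1, & m = n \end{cases}
$$

Reflection $i$ consumes $m-i-1$ angles, which gives the length checked by the first `assert`:

$$
\sum_{i=0}^{n-1} (m-i-1) = mn - \frac{n(n+1)}{2} \quad (m > n), \qquad
\sum_{i=0}^{n-2} (n-i-1) = \frac{n(n-1)}{2} = n^2 - \frac{n(n+1)}{2} \quad (m = n)
$$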
My network consists of custom layers that use these orthogonal matrices by multiplying them with the input. In the `__init__` method I initialize the necessary number of orthogonal matrices, about 130000 floats in total including the parameters:
```python
def __init__(self, ....):
    # ... (other initialization)

    ############ initialization of orthogonal matrices
    self.weights_orth = nn.Parameter(torch.FloatTensor(num_matrices, num_params).uniform_(0, 2*np.pi))
    orth_mat = Orth_mat(self.shape)
    index_orth = 0
    for i in range(num_matrices):
        if index_orth == 0:
            self.orth = orth_mat(self.weights_orth[i]).unsqueeze(0)
        else:
            self.orth = torch.cat([self.orth, orth_mat(self.weights_orth[i]).unsqueeze(0)], dim=0)
        index_orth += 1
```
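The loop above builds every matrix once, in `__init__`, through many small per-element operations. A sketch of the same Householder construction, vectorised over all matrices and rebuilt in `forward` on every call, might look as follows. It assumes a recent PyTorch release (where the `Variable` wrapper is no longer needed) and an input of shape `(batch, num_matrices, n)`; `householder_orth` and `OrthLayer` are hypothetical names, and I have not profiled its memory use.

```python
import math

import torch
import torch.nn as nn


def householder_orth(theta, m, n):
    """Hypothetical batched version of the Orth_mat construction.

    theta: (B, m*n - n*(n+1)//2) angles; returns (B, m, n) with orthonormal columns.
    """
    B = theta.shape[0]
    loop_len = n if m > n else n - 1
    Q = torch.eye(m, dtype=theta.dtype, device=theta.device).expand(B, m, m)
    ones = torch.ones(B, 1, dtype=theta.dtype, device=theta.device)
    offset = 0
    for i in range(loop_len):
        d = m - i                                    # non-zero length of u for reflection i
        angles = theta[:, offset:offset + d - 1]     # (B, d-1) angles of this reflection
        offset += d - 1
        # hyperspherical coordinates: u_j = cos(a_j) * prod_{k<j} sin(a_k), last entry = prod of all sines
        sin_prod = torch.cumprod(torch.cat([ones, torch.sin(angles)], dim=1), dim=1)  # (B, d)
        cos_part = torch.cat([torch.cos(angles), ones], dim=1)                         # (B, d)
        u = torch.cat([sin_prod * cos_part, theta.new_zeros(B, m - d)], dim=1)          # (B, m)
        # left-multiply by H = I - 2 u u^T without forming H: H Q = Q - 2 u (u^T Q)
        Q = Q - 2 * u.unsqueeze(2) * (u.unsqueeze(1) @ Q)
    return Q[:, :, :n]                               # same as Q @ I_{m x n}


class OrthLayer(nn.Module):
    """Hypothetical layer applying num_matrices orthogonal (m x n) matrices to the input."""

    def __init__(self, num_matrices, m, n):
        super().__init__()
        self.m, self.n = m, n
        num_params = m * n - n * (n + 1) // 2
        self.weights_orth = nn.Parameter(
            torch.empty(num_matrices, num_params).uniform_(0, 2 * math.pi))

    def forward(self, x):
        # rebuilt on every call, so each training step gets a fresh graph that backward() frees
        orth = householder_orth(self.weights_orth, self.m, self.n)  # (num_matrices, m, n)
        # assumed input layout: x is (batch, num_matrices, n) -> output (batch, num_matrices, m)
        return torch.einsum('kmn,bkn->bkm', orth, x)
```

The hyperspherical unit vector comes from one `cumprod` of sines instead of nested Python loops, and each reflection is applied as `Q - 2 u (uᵀ Q)` without materialising the `m x m` matrix `H_i`.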
The problem is that initialization alone takes almost 10 GB of CPU memory. Afterwards, during training, CPU memory consumption keeps growing, although apart from initialization I perform all operations on CUDA. On the other hand, I noticed that if I don’t use the Variable API in the `Orth_mat` class, there is no memory issue, but then I cannot compute gradients (at least with the latest stable version of PyTorch, 0.3.1).
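One thing I am unsure about: in my `__init__`, `self.orth` is computed from `self.weights_orth` only once. As far as I understand autograd, a tensor computed from a parameter in `__init__` carries a graph that every later forward pass reuses. The first `backward()` frees that graph's saved buffers, a second one then raises a `RuntimeError` unless `retain_graph=True` is passed (which keeps the buffers alive), and the tensor is not recomputed when the optimizer changes the parameter. A toy sketch of that pattern, assuming a recent PyTorch release (`PrecomputedInInit` is a hypothetical module):

```python
import torch
import torch.nn as nn


class PrecomputedInInit(nn.Module):
    """Hypothetical toy module: a derived tensor is built once in __init__, like self.orth."""

    def __init__(self):
        super().__init__()
        self.theta = nn.Parameter(torch.rand(4))
        # the graph of this op (sin saves its input for backward) is created here, once
        self.derived = torch.sin(self.theta)

    def forward(self, x):
        return (self.derived * x).sum()


torch.manual_seed(0)
model = PrecomputedInInit()
x = torch.randn(4)
model(x).backward()          # frees the saved buffers of the graph built in __init__

try:
    model(x).backward()      # has to traverse the same, already freed graph again
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True

# simulate an optimizer step: the parameter changes, the precomputed tensor does not
with torch.no_grad():
    model.theta.add_(1.0)
is_fresh = torch.allclose(model.derived, torch.sin(model.theta))
```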
First, I don’t understand why the Variable API causes this memory problem; it seems really strange. Second, is there a more efficient way to implement this kind of orthogonal-matrix parametrization in PyTorch?
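A different technique, swapped in here rather than a fix of the Householder code: the matrix exponential of a skew-symmetric matrix is orthogonal, so one batched `torch.matrix_exp` call can replace the Python loops entirely. This sketch assumes a recent PyTorch release that provides `torch.matrix_exp` and `torch.triu_indices`; `ExpOrth` is a hypothetical name. The trade-offs: it uses $m(m-1)/2$ parameters per matrix instead of $mn - n(n+1)/2$ (redundant when $n < m$), and for $m = n$ it only reaches matrices with determinant $+1$.

```python
import torch
import torch.nn as nn


class ExpOrth(nn.Module):
    """Hypothetical alternative: Q = first n columns of exp(A), with A skew-symmetric (A^T = -A)."""

    def __init__(self, num_matrices, m, n):
        super().__init__()
        assert m >= n
        self.m, self.n = m, n
        idx = torch.triu_indices(m, m, offset=1)          # strictly upper-triangular positions
        self.register_buffer('idx', idx)                  # moves with .to(device) / .cuda()
        self.params = nn.Parameter(0.1 * torch.randn(num_matrices, idx.shape[1]))

    def forward(self):
        k = self.params.shape[0]
        A = self.params.new_zeros(k, self.m, self.m)
        A[:, self.idx[0], self.idx[1]] = self.params      # fill the upper triangle
        A = A - A.transpose(1, 2)                         # make it skew-symmetric
        Q = torch.matrix_exp(A)                           # exp of a skew-symmetric matrix is orthogonal
        return Q[:, :, :self.n]                           # keep n orthonormal columns
```

Recent PyTorch versions also ship `torch.nn.utils.parametrizations.orthogonal`, which keeps a module's weight orthogonal during training and may be worth a look as well.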
Thank you in advance!