Replace diagonal elements with vector

If it doesn’t complain about a tensor needed for gradient computation being modified in place (and you are not working with .data), then it should be fine.
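Here is a minimal sketch of what "complaining" looks like (tensor names and values are arbitrary). Overwriting the diagonal of an intermediate result works when autograd did not save that result for backward. It raises an error when autograd did save it.

```python
import torch

v = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Case 1: multiplying by a constant does not save its output for backward,
# so overwriting the diagonal of b in place is allowed.
a = torch.randn(3, 3, requires_grad=True)
b = a * 2
b.diagonal().copy_(v)  # diagonal() is a view, so this writes into b
b.sum().backward()
print(a.grad)  # 2 off the diagonal, 0 on it (those entries were overwritten)
print(v.grad)  # ones: each element of v now sits on the diagonal of b

# Case 2: exp() saves its output for backward, so modifying that output
# in place invalidates the saved value and backward() raises an error.
a2 = torch.randn(3, 3, requires_grad=True)
c = a2.exp()
c.diagonal().copy_(v.detach())
complained = False
try:
    c.sum().backward()
except RuntimeError as err:
    complained = True
    print("autograd complained:", err)
```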

@SimonW When I put the resulting variable in the optimizer, I’m getting the error:

ValueError: can’t optimize a non-leaf Tensor

In general, you can’t optimize intermediate (non-leaf) results. If this is for initializing the tensor, you can use result.detach().requires_grad_() as the parameter instead. However, note that the diagonal may change during optimization, since nothing enforces the constraint on a free leaf tensor.
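The following small sketch shows the leaf/non-leaf distinction. The names `base`, `result` and `init` are made up for the example. The optimizer rejects a tensor produced by an operation, while `result.detach().requires_grad_()` produces a fresh leaf that it accepts.

```python
import torch

dim = 3
base = torch.zeros(dim, dim, requires_grad=True)

# Anything computed from a tensor that requires grad is a non-leaf tensor.
result = base + torch.diag(torch.exp(torch.randn(dim)))
print(result.is_leaf)  # False

rejected = False
try:
    torch.optim.SGD([result], lr=0.1)
except ValueError as err:  # the optimizer refuses non-leaf tensors
    rejected = True
    print(err)

# detach() cuts the tensor out of the graph; requires_grad_() turns it into a new leaf.
init = result.detach().requires_grad_()
print(init.is_leaf)  # True
optimizer = torch.optim.SGD([init], lr=0.1)  # accepted
```

Inside an nn.Module, wrapping the same value as `torch.nn.Parameter(result.detach())` has the same effect. Because `init` is now an ordinary free tensor, optimizer steps can change any of its entries.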

I still get the same error. What I am trying to do is create a lower triangular Variable and optimize its diagonal while keeping it positive. This is the code in the constructor of my class:

    self.L_chol_cov_theta = Parameter(torch.tensor(torch.zeros(dim, dim).cuda(), requires_grad=True))
    self.log_diag_L_chol_cov_theta = Parameter(torch.tensor(torch.zeros(dim).cuda(), requires_grad=True))
    self.L = Parameter(torch.tensor(torch.tril(torch.randn(dim, dim)).cuda(), requires_grad=True))
    self.L_chol_cov_theta.data = torch.tril(self.L_chol_cov_theta.data)
    self.L_chol_cov_theta.data -= torch.diag(torch.diag(self.L_chol_cov_theta.data))
    self.L.data = self.L_chol_cov_theta.data + torch.diag(torch.exp(self.log_diag_L_chol_cov_theta.data))

And this is the code in the forward pass of my class:

    self.L_chol_cov_theta.data -= torch.diag(torch.diag(self.L_chol_cov_theta.data))
    self.L.data = self.L_chol_cov_theta.data + torch.diag(torch.exp(self.log_diag_L_chol_cov_theta.data))

For some reason, though, the gradients are not calculated, and I cannot figure out why…

If you are working with .data, the operations won’t be tracked by autograd, and gradients for the tracked ones may even be silently wrong.

See the “What about .data?” section of https://pytorch.org/2018/04/22/0_4_0-migration-guide.html .
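The following sketch is modelled on the example in that section of the migration guide, with arbitrary tensor values. It shows both problems. First, work done through `.data` is invisible to autograd. Second, an in-place change made through `.data` to a value saved for backward silently corrupts the gradient, whereas the same change made through `.detach()` is caught.

```python
import torch

x = torch.ones(3, requires_grad=True)

# 1) Results computed from .data carry no autograd history.
y = x.data * 2
print(y.requires_grad)  # False

# 2) exp() saves its output w for backward. Modifying w through .data is
#    not detected, so backward uses the modified value.
z = x * 3
w = z.exp()
w.data.add_(1.0)
w.sum().backward()
print(x.grad)  # 3 * (e**3 + 1) per element, but the true gradient is 3 * e**3

# 3) The same change through detach() shares w2's version counter,
#    so autograd notices it and raises instead of returning wrong gradients.
x2 = torch.ones(3, requires_grad=True)
w2 = (x2 * 3).exp()
w2.detach().add_(1.0)
detach_raised = False
try:
    w2.sum().backward()
except RuntimeError:
    detach_raised = True
print(detach_raised)  # True
```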

@SimonW I already read all about it, but thanks for sharing. Could you suggest how to refactor the code above to be compatible with PyTorch 0.4? The guide is not very detailed about the semantics of detach() when initializing variables and in the forward pass, and it is hard to follow for someone new to PyTorch.

Please find below a part of the problematic code, which illustrates the problem I’m facing: autograd does not track the changes to the noise and L parameters in the forward pass.
P.S.: I removed a lot of unnecessary parts of the code.


import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn import ParameterList

class Model(torch.nn.Module):
    def __init__(self, dim):
        """
        Constructor.
        """
        super(Model, self).__init__()
   
        self.noise_vector = Parameter(torch.tensor(torch.zeros(dim).cuda(), requires_grad=True))
        self.noise = Parameter(torch.tensor(torch.diag(torch.exp(self.noise_vector.data)).cuda(), requires_grad=True))

        self.L_chol_cov_theta = Parameter(torch.tensor(torch.randn(dim, dim).cuda(), requires_grad=True))
        self.log_diag_L_chol_cov_theta = Parameter(torch.tensor(torch.randn(dim).cuda(), requires_grad=True))
        self.L = Parameter(torch.tensor(torch.randn(dim, dim).cuda(), requires_grad=True))
        self.L_chol_cov_theta.data = torch.tril(self.L_chol_cov_theta.data)
        self.L_chol_cov_theta.data -= torch.diag(torch.diag(self.L_chol_cov_theta.data))
        self.L.data = self.L_chol_cov_theta.data + torch.diag(torch.exp(self.log_diag_L_chol_cov_theta.data))

    def forward(self):
        # update parameters
        self.L_chol_cov_theta.data -= torch.diag(torch.diag(self.L_chol_cov_theta.data))
        self.L.data = self.L_chol_cov_theta.data + torch.diag(torch.exp(self.log_diag_L_chol_cov_theta.data))

        self.noise.data = torch.diag(torch.exp(self.noise_vector.data))

        return torch.mm(self.L, self.noise_vector.view(-1, 1))

model = Model(5)
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

for epoch in range(10):
   optimizer.zero_grad()
   forward_pass_something = model()
   loss = calc_likelihood(samples, a_ground_truth)  # custom loss; calc_likelihood, samples and a_ground_truth are defined elsewhere (omitted)
   loss.backward()
   optimizer.step()

Hi Simon, I want to initialize a lower triangular matrix with a positive diagonal in the constructor, and then optimize the variables as they are. Then, in each forward pass, I want to manually manipulate the data of that matrix to make it lower triangular (tril) with a positive diagonal again. Is that possible?
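A common alternative to editing `.data` in every forward pass is a reparametrization (not something proposed in this thread). You keep unconstrained leaf parameters and rebuild the constrained matrix from them with differentiable operations inside `forward()`. The optimizer updates the leaves, and the rebuilt matrix is lower triangular with a positive diagonal by construction. Below is a minimal CPU sketch. The class name, the parameter names and the placeholder loss are hypothetical.

```python
import torch
from torch import nn


class LowerTriangularPositiveDiag(nn.Module):
    """Hypothetical module: L = strictly lower part of `raw` + diag(exp(log_diag))."""

    def __init__(self, dim):
        super().__init__()
        # Unconstrained leaf parameters; these are what the optimizer updates.
        self.raw = nn.Parameter(torch.randn(dim, dim))
        self.log_diag = nn.Parameter(torch.zeros(dim))

    def forward(self):
        # Rebuilt from the parameters with differentiable ops on every call,
        # so gradients reach both self.raw and self.log_diag.
        strictly_lower = torch.tril(self.raw, diagonal=-1)
        return strictly_lower + torch.diag(torch.exp(self.log_diag))


model = LowerTriangularPositiveDiag(5)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for step in range(10):
    optimizer.zero_grad()
    L = model()
    # Placeholder loss for illustration only; substitute the real likelihood here.
    loss = (L @ L.T).diagonal().sum()
    loss.backward()
    optimizer.step()
```

The diagonal and upper-triangular entries of `raw` are simply ignored and always receive zero gradient, because `torch.tril(..., diagonal=-1)` keeps only the strictly lower part.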

I know this is probably late, but it is for the wanderers out there. Since PyTorch v0.4.1, torch.diagonal() works similarly to np.diagonal(), and it returns a view of the input tensor. You can combine it with tensor.copy_() to replace the diagonal elements of an N-dimensional tensor with a vector as follows:

tensor.diagonal(dim1=-2, dim2=-1).copy_(vector)

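This method is very general: it operates on the last two dimensions of the tensor (so a batch of matrices works too), even if those two dimensions are not equal; for an m x n matrix the vector should have min(m, n) elements. Now, to do exactly what the original question asked: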

import torch

D = 3
k = 0
# L_1 = np.tril(np.random.normal(scale=1., size=(D, D)), k=k)
# L_1[np.diag_indices_from(L_1)] = np.exp(np.diagonal(L_1))
L_1 = torch.randn(D, D).tril_(diagonal=k)
L_1.diagonal(dim1=-2, dim2=-1).exp_()
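The following small sketch illustrates the generality claim above, using arbitrary shapes. It takes a batch of two 3 x 4 matrices and copies a length-3 vector into each diagonal, broadcasting the vector over the batch dimension.

```python
import torch

# A batch of two 3x4 matrices: the last two dimensions need not be equal.
t = torch.zeros(2, 3, 4)
vector = torch.tensor([1.0, 2.0, 3.0])  # length min(3, 4) = 3

# diagonal() returns a view of shape (2, 3), so copy_() writes straight into t;
# the vector is broadcast across the batch dimension.
t.diagonal(dim1=-2, dim2=-1).copy_(vector)
print(t[0])
# tensor([[1., 0., 0., 0.],
#         [0., 2., 0., 0.],
#         [0., 0., 3., 0.]])
```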

Old topic, but a newer and easier way is to use Tensor.fill_diagonal_. Note that it fills the diagonal with a single scalar value rather than a vector:

In [116]: t = torch.randint(20, (3, 4))

In [117]: t
Out[117]:
tensor([[10, 14, 16,  4],
        [ 8,  9,  9, 17],
        [ 3, 16,  5, 16]])

In [118]: t.fill_diagonal_(100)
Out[118]:
tensor([[100,  14,  16,   4],
        [  8, 100,   9,  17],
        [  3,  16, 100,  16]])

In [119]: t
Out[119]:
tensor([[100,  14,  16,   4],
        [  8, 100,   9,  17],
        [  3,  16, 100,  16]])

Note the _ at the end of fill_diagonal_; it is an in-place operation. There is no out-of-place version to my knowledge.

To keep the original tensor unmodified, I did:
t.detach().clone().fill_diagonal_(100)
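Note that `t.detach().clone()` also drops the autograd history. That is fine for plain data but not when gradients must flow back to `t`. In that case, one option is to clone without detaching and then use the `diagonal().copy_()` approach from the earlier answer. PyTorch 1.11 and later also provide `torch.diagonal_scatter`, an out-of-place function that accepts a vector. Here is a minimal sketch with arbitrary values:

```python
import torch

t = torch.randn(3, 4, requires_grad=True)
original = t.detach().clone()  # kept only to check that t is not modified
vector = torch.tensor([100.0, 200.0, 300.0])

# Option 1: clone() without detach() keeps the copy connected to t,
# then the diagonal of the copy is overwritten in place.
out1 = t.clone()
out1.diagonal().copy_(vector)

# Option 2 (PyTorch 1.11+): a purely out-of-place equivalent.
out2 = torch.diagonal_scatter(t, vector)

(out1.sum() + out2.sum()).backward()
# Each option contributes 1 to every off-diagonal entry of t.grad and 0 on the diagonal.
print(t.grad)
```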
