Can't flow gradients through a newly constructed tensor

When I try to construct a tensor (a 3x3 matrix in my case) from multiple other tensors that require grad, the resulting tensor does not have requires_grad = True by default, and even if I set it manually, the tensors that were used to construct the matrix do not receive gradients when backpropagating. Why is this?

example code:

import torch

test0 = torch.ones(1, dtype=torch.float32, requires_grad=True)
test1 = torch.ones(1, dtype=torch.float32, requires_grad=True)
test2 = torch.ones(1, dtype=torch.float32, requires_grad=True)
test_mat = torch.tensor([[test0, test1, 0], [1, 0, 0], [test2, test1, 1]])
print(test_mat.requires_grad)      # False
test_mat.requires_grad_(True)      # setting it manually, as described above
total = test_mat.sum()             # renamed so as not to shadow the built-in sum()
total.backward()
print(test0.grad is None)          # True: no gradient reached test0

Output:

False
True

How can I construct a tensor out of other tensors so that gradients flow properly? I want to assemble the individual entries into a matrix and then multiply another (3xN) tensor by it, rather than optimizing a full 3x3 matrix directly, because some of its entries are constants and I need the matrix to have certain properties, such as symmetry, which can't be enforced if I optimize all nine entries independently.

Edit: For now I am creating an empty 3x3 matrix and assigning the elements into it directly, roughly as sketched below. Is this the intended way?
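
For concreteness, a minimal sketch of what I mean (same variable names as above):

import torch

test0 = torch.ones(1, dtype=torch.float32, requires_grad=True)
test1 = torch.ones(1, dtype=torch.float32, requires_grad=True)
test2 = torch.ones(1, dtype=torch.float32, requires_grad=True)

# Start from a constant matrix and write the parameter entries into it.
test_mat = torch.zeros(3, 3)
test_mat[0, 0] = test0
test_mat[0, 1] = test1
test_mat[1, 0] = 1.0
test_mat[2, 0] = test2
test_mat[2, 1] = test1
test_mat[2, 2] = 1.0

test_mat.sum().backward()
print(test0.grad)  # tensor([1.])
print(test1.grad)  # tensor([2.]) -- test1 appears in two entries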

Thank you.

Hi Yotam!

The short answer is that this is how the factory function
torch.tensor() works: it copies the values of its arguments into a
brand-new leaf tensor, so the result is detached from the autograd
graph and no gradients can flow back to test0, test1, and test2.

Creating an empty matrix and assigning elements into it, as in your
edit, is perfectly fine – indexed assignment is tracked by autograd,
so gradients will flow back to the assigned tensors.

As an alternative – not necessarily better – you can use cat() and
stack() to compose test_mat out of the desired pieces:

>>> import torch
>>> print (torch.__version__)
1.12.0
>>>
>>> test0 = torch.ones (1, dtype = torch.float32, requires_grad = True)
>>> test1 = torch.ones (1, dtype = torch.float32, requires_grad = True)
>>> test2 = torch.ones (1, dtype = torch.float32, requires_grad = True)
>>>
>>> test_mat = torch.stack (
...     (
...         torch.cat ((test0, test1, torch.tensor ([0.0]))),
...         torch.tensor ([1.0, 0.0, 1.0]),
...         torch.cat ((test2, test1, torch.tensor ([1.0])))
...     )
... )
>>>
>>> test_mat
tensor([[1., 1., 0.],
        [1., 0., 1.],
        [1., 1., 1.]], grad_fn=<StackBackward0>)
>>>
>>> test_mat.sum().backward()
>>>
>>> test0.grad
tensor([1.])
>>> test1.grad
tensor([2.])
>>> test2.grad
tensor([1.])
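
As a side note, composing the matrix this way also lets you enforce
symmetry by construction: reuse the same tensor for the (i, j) and
(j, i) entries. A small sketch (the parameter names and the 3xN
tensor x are just for illustration):

import torch

a = torch.ones(1, requires_grad=True)  # diagonal entry
b = torch.ones(1, requires_grad=True)  # shared off-diagonal entry
c = torch.ones(1, requires_grad=True)  # diagonal entry

zero = torch.tensor([0.0])
one = torch.tensor([1.0])

# b is used at both (0, 1) and (1, 0), so sym_mat is symmetric by construction.
sym_mat = torch.stack((
    torch.cat((a, b, zero)),
    torch.cat((b, c, zero)),
    torch.cat((zero, zero, one)),
))

x = torch.randn(3, 5)  # the 3xN tensor to be transformed
y = sym_mat @ x
y.sum().backward()
print(a.grad, b.grad, c.grad)  # b accumulates gradient from both entries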

Which approach is preferable is a matter of style and perhaps of the
details of your use case.

Best.

K. Frank