# Proper way to create tensor from scalars

Suppose I have a scalar tensor `t` with `requires_grad=True`, and I would like to merge it into a larger tensor, for example:

```
# Tensor created from scalar `t`; pseudo-code.
M = [[ t**2   , sin(t)      ],
     [ cos(t) , sqrt(t + 5) ]]
```

The new tensor `M` is then used to perform some computations, and finally the gradients (w.r.t. `t`) are computed by calling `backward` on the result.

Right now I can think of two ways to create the larger tensor `M` from the scalar `t`.

Here’s a code example to illustrate what I mean:

```
from functools import partial
from torch import cos, sin, sqrt, stack, tensor
import torch

tensor = partial(tensor, dtype=torch.float64)

# Scalar with gradient tracking enabled.
t = tensor(2.0, requires_grad=True)

# Option 1: build the rows with nested `stack` calls.
M = stack([
    stack([ t**2   , sin(t)      ]),
    stack([ cos(t) , sqrt(t + 5) ]),
])

# Option 2: sum of scaled constant "basis" matrices.
M = (
      t**2        * tensor([[1, 0], [0, 0]])
    + sin(t)      * tensor([[0, 1], [0, 0]])
    + cos(t)      * tensor([[0, 0], [1, 0]])
    + sqrt(t + 5) * tensor([[0, 0], [0, 1]])
)

x = tensor([1, 2])
y = M @ x
result = y @ y

result.backward()
```
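For what it's worth, both options should produce the same gradient. A quick sanity check along these lines (the value `2.0` for `t` is an arbitrary choice, and the finite-difference comparison is just a rough cross-check):

```
import torch
from torch import cos, sin, sqrt, stack

def result_of(t):
    # Same computation as above, using the nested-stack variant.
    M = stack([
        stack([ t**2   , sin(t)      ]),
        stack([ cos(t) , sqrt(t + 5) ]),
    ])
    x = torch.tensor([1.0, 2.0], dtype=torch.float64)
    y = M @ x
    return y @ y

t = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)
(grad,) = torch.autograd.grad(result_of(t), t)

# Finite-difference approximation of d(result)/dt.
eps = 1e-6
fd = (result_of(t.detach() + eps) - result_of(t.detach() - eps)) / (2 * eps)
print(grad.item(), fd.item())  # the two numbers should agree closely
```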

Now my question is: what is the preferred way to accomplish this task? Is there perhaps an even more dedicated method? I am interested in all aspects, including the performance of the resulting graph.


You can just pass the values as a list to `torch.tensor`:

```
N = torch.tensor([[t**2, torch.sin(t)],
                  [torch.cos(t), torch.sqrt(t + 5.)]])
```

Just calling `tensor` on lists doesn’t seem to work (any more)?

```
import torch
from torch import sin, cos, sqrt

t = torch.tensor(2.0, requires_grad=True)  # example scalar

M = [[ t**2   , sin(t)      ],
     [ cos(t) , sqrt(t + 5) ]]

z = torch.tensor(M)
print((z.device, z.requires_grad))
# (device(type='cpu'), False)
```
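For comparison, here is a minimal sketch (scalar value chosen arbitrarily; behaviour may differ slightly across PyTorch versions) showing that `torch.stack` keeps the autograd graph where `torch.tensor` drops it:

```
import torch
from torch import sin, stack

t = torch.tensor(2.0, requires_grad=True)

# torch.tensor copies the values and detaches them from the graph ...
a = torch.tensor([t**2, sin(t)])
print(a.requires_grad)  # False

# ... while torch.stack keeps the autograd history intact.
b = stack([t**2, sin(t)])
print(b.requires_grad)  # True
```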

Are nested calls to `stack` the only option for this?

Here’s a quick nested `stack` implementation in case it’s useful to anyone:

```
import torch

def _nd_peek(x):
    """Return the first element, if any, of nested
    iterable ``x`` that is a ``torch.Tensor``.
    """
    if isinstance(x, torch.Tensor):
        return x
    elif isinstance(x, (tuple, list)):
        for el in x:
            res = _nd_peek(el)
            if res is not None:
                return res

def _nd_stack(x, device):
    """Recursively stack ``x`` into a ``torch.Tensor``,
    creating any constant elements encountered on ``device``.
    """
    if isinstance(x, (tuple, list)):
        # sequence -> convert the elements recursively and stack them
        return torch.stack([_nd_stack(el, device) for el in x])
    elif isinstance(x, torch.Tensor):
        # torch element
        return x
    else:
        # torch doesn't like you mixing devices, so create
        # constant elements on the shared device (default dtype)
        return torch.tensor(x, dtype=torch.get_default_dtype(), device=device)

def torch_array(x):
    """Convert ``x`` into a ``torch.Tensor``, respecting the device
    of any tensors already contained in ``x``, e.g.::

        torch_array([[0, t], [t, 0]])
    """
    ref = _nd_peek(x)
    device = ref.device if ref is not None else None
    return _nd_stack(x, device)
```
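A quick usage check (values chosen arbitrarily) that the result stays connected to the graph:

```
t = torch.tensor(2.0, requires_grad=True)

M = torch_array([[0, t], [t, 0]])
print(M.requires_grad)  # True

# `t` appears twice in M, so d(M.sum())/dt == 2.
M.sum().backward()
print(t.grad)  # tensor(2.)
```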