Let’s say that, given a tensor of length 3 with `requires_grad=True`, I want to manually create a 3x3 skew-symmetric matrix from that tensor. As a PyTorch newbie, this is what I would expect to work:
```python
import torch

def variant_1(x):
    skew_symmetric_mat = torch.tensor([
        [0.0, -x[2], x[1]],
        [x[2], 0.0, -x[0]],
        [-x[1], x[0], 0.0]
    ])
    return skew_symmetric_mat

vec = torch.rand(3, requires_grad=True)
variant_1(vec).backward(torch.ones(3, 3))  # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
However, `variant_1` fails with the runtime error shown in the snippet. I guess the underlying problem is that constructing the tensor `skew_symmetric_mat` from elements of another tensor `x` is not a differentiable operation: when `skew_symmetric_mat` is initialized, only the *values* of `x` get copied into it, so no backward graph is created that references the elements of `x`.
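Indeed, a quick check seems to confirm this: the matrix returned by `variant_1` has no connection back to `vec`:

```python
vec = torch.rand(3, requires_grad=True)
mat = variant_1(vec)
print(mat.requires_grad)  # False: only the values of `vec` were copied
print(mat.grad_fn)        # None: no backward graph references `vec`
```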
With this assumption in mind, I was able to write a fully working function `variant_2()`, which doesn’t assign the elements of `x` to `skew_symmetric_mat`, but instead multiplies them by new constant tensors, which most probably results in proper creation of the DAG:
```python
def variant_2(x):
    skew_symmetric_mat = torch.zeros(3, 3)
    skew_symmetric_mat += x[0] * torch.tensor([
        [0.0, 0.0, 0.0],
        [0.0, 0.0, -1.0],
        [0.0, 1.0, 0.0]
    ])
    skew_symmetric_mat += x[1] * torch.tensor([
        [0.0, 0.0, 1.0],
        [0.0, 0.0, 0.0],
        [-1.0, 0.0, 0.0]
    ])
    skew_symmetric_mat += x[2] * torch.tensor([
        [0.0, -1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0]
    ])
    return skew_symmetric_mat

vec = torch.rand(3, requires_grad=True)
variant_2(vec).backward(torch.ones(3, 3))  # Computes `vec.grad` just fine
```
My question is: `variant_2` seems a bit too verbose for my liking, and it also seems computationally wasteful. Surely there must be a way to write code that is as compact as `variant_1()` while also being computationally efficient. How would you go about writing this?
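For reference, my best guess so far is that something like `torch.stack` might be the right tool here, but I’m not sure whether this sketch is idiomatic or efficient:

```python
def variant_3(x):
    # My rough attempt (the name and approach are just my guess):
    # stacking the nine scalar entries should keep the autograd graph intact.
    zero = x.new_zeros(())  # 0-dim zero matching x's dtype and device
    return torch.stack([
        zero, -x[2], x[1],
        x[2], zero, -x[0],
        -x[1], x[0], zero,
    ]).reshape(3, 3)
```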
P.S. Apologies for such a trivial question. I couldn’t even find the right terminology to Google a solution.