Backpropagation through List Comprehension

Hi everyone,
I want to create a tensor of shape (N*N, 2) from a tensor of shape (N, 1) that contains all possible combinations of its elements (ordered pairs, with duplicates). It is part of a bigger pipeline, but I realised that the gradients do not flow; this is what the code looks like:

x_graph = torch.tensor([[-1], [0], [1]], dtype=torch.float, requires_grad=True)
stacked_x = torch.stack([torch.tensor([n,i], requires_grad=True) for n in x_graph for i in x_graph])

This is the content of stacked_x:

tensor([[-1., -1.],
        [-1.,  0.],
        [-1.,  1.],
        [ 0., -1.],
        [ 0.,  0.],
        [ 0.,  1.],
        [ 1., -1.],
        [ 1.,  0.],
        [ 1.,  1.]], grad_fn=<StackBackward0>)

However, calling torch.autograd.grad(stacked_x.sum(), x_graph)[0] raises the following error:

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_547127/2517417709.py in <module>
----> 1 torch.autograd.grad(stacked_x.sum(), x_graph)[0]

~/miniconda3/envs/test_torch/lib/python3.7/site-packages/torch/autograd/__init__.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
    300         return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    301             t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
--> 302             allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass
    303 
    304 

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

It is probably a simple mistake … thank you very much!

UPDATE: I found a solution using torch.combinations, but I am still curious how to fix the original error. Thanks!
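The update doesn't show the torch.combinations code, so here is a hedged sketch of the library route instead. It assumes x_graph is first squeezed to 1-D (both functions below take 1-D tensors). torch.cartesian_prod, swapped in here, produces exactly the N*N ordered pairs with gradients intact, while torch.combinations(..., with_replacement=True) returns only the N*(N+1)/2 unordered pairs:

```python
import torch

x_graph = torch.tensor([[-1], [0], [1]], dtype=torch.float, requires_grad=True)

# Both functions below take 1-D tensors, so drop the trailing size-1 dimension.
# squeeze() returns a view, so gradients still flow back to x_graph.
v = x_graph.squeeze(1)

# All N*N ordered pairs, in the same order as the nested list comprehension.
pairs = torch.cartesian_prod(v, v)  # shape (9, 2)

# Only unordered pairs (repeats allowed): N*(N+1)/2 rows.
combos = torch.combinations(v, r=2, with_replacement=True)  # shape (6, 2)

grad = torch.autograd.grad(pairs.sum(), x_graph)[0]
print(pairs.shape, combos.shape)
print(grad)  # each element appears 2*N = 6 times, so every entry is 6
```

Which one matches "all possible combinations (with duplicates)" depends on whether (a, b) and (b, a) should both appear; the (N*N, 2) shape suggests cartesian_prod.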

Hi Andreas!

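In your list comprehension, n and i are indeed pytorch tensors. However,
the Tensor factory function, torch.tensor(), is expecting python scalars,
so pytorch (helpfully) converts the tensors you pass in to python scalars.
This “breaks the computation graph,” so backpropagation fails. Adding
requires_grad = True to your call to torch.tensor() doesn’t fix the
problem: it only turns each newly created tensor into a fresh leaf with
no connection back to x_graph. (That is why stacked_x still shows
grad_fn=<StackBackward0>: the stack() is recorded, but its inputs are
those new leaves, not x_graph.)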
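A minimal sketch of that graph break, using single elements of x_graph (0-dim tensors, so torch.tensor() can turn each into a scalar). The tensor built by torch.tensor() is a fresh leaf with no grad_fn, so asking for a gradient with allow_unused=True (the option named in the error message) returns None, whereas torch.stack() records the link back to x_graph:

```python
import torch

x_graph = torch.tensor([[-1], [0], [1]], dtype=torch.float, requires_grad=True)

# torch.tensor() copies the *values* of the two elements into a brand-new tensor;
# requires_grad=True just makes that new tensor a leaf of its own graph.
copied = torch.tensor([x_graph[0, 0], x_graph[1, 0]], requires_grad=True)
print(copied.is_leaf, copied.grad_fn)  # True None

# allow_unused=True returns None instead of raising: x_graph was never used.
g_copied = torch.autograd.grad(copied.sum(), x_graph, allow_unused=True)[0]
print(g_copied)  # None

# torch.stack() is a differentiable op, so the result stays connected to x_graph.
stacked = torch.stack([x_graph[0, 0], x_graph[1, 0]])
print(stacked.is_leaf)  # False (grad_fn is a StackBackward0 node)
g_stacked = torch.autograd.grad(stacked.sum(), x_graph)[0]
print(g_stacked)  # 1 for the two elements used, 0 for the third
```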

Avoiding torch.tensor() works. Here is a modified version of your code
that uses torch.cat() to build the tensors that you then stack():

>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> x_graph = torch.tensor([[-1], [0], [1]], dtype=torch.float, requires_grad=True)
>>> stacked_x = torch.stack ([torch.cat ([x_graph[n], x_graph[i]]) for n in range (len (x_graph)) for i in range (len (x_graph))]
... )
>>>
>>> stacked_x
tensor([[-1., -1.],
        [-1.,  0.],
        [-1.,  1.],
        [ 0., -1.],
        [ 0.,  0.],
        [ 0.,  1.],
        [ 1., -1.],
        [ 1.,  0.],
        [ 1.,  1.]], grad_fn=<StackBackward0>)
>>>
>>> torch.autograd.grad(stacked_x.sum(), x_graph)[0]
tensor([[6.],
        [6.],
        [6.]])
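A loop-free variant, sketched here as an alternative: repeat_interleave() and repeat() build the two columns in the same order as the nested comprehension, and cat() joins them without leaving the graph. Each element appears N times in each column, which is why every gradient entry is 2*N = 6:

```python
import torch

x_graph = torch.tensor([[-1], [0], [1]], dtype=torch.float, requires_grad=True)
n = x_graph.shape[0]

# First column: each element repeated n times in a row -> x0,x0,x0,x1,x1,x1,...
first = x_graph.repeat_interleave(n, dim=0)
# Second column: the whole column tiled n times -> x0,x1,x2,x0,x1,x2,...
second = x_graph.repeat(n, 1)
stacked_x = torch.cat([first, second], dim=1)  # shape (n*n, 2)

# Reference: the loop-based construction using cat() and stack().
loop_x = torch.stack([torch.cat([x_graph[a], x_graph[b]])
                      for a in range(n) for b in range(n)])

grad = torch.autograd.grad(stacked_x.sum(), x_graph)[0]
print(torch.equal(stacked_x, loop_x))  # True
print(grad)  # every entry is 2 * n = 6
```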

Best.

K. Frank
