PyTorch does not backpropagate through an iterative tensor construction

Hi PyTorch people,

I am currently trying to build a tensor iteratively in PyTorch. Sadly, backprop does not work with the in-place operations in the loop. I have already tried equivalent programs using stack, for example. Does somebody know how I could build the tensor so that backprop works?

This is a minimal example that reproduces the error:

import torch
from torch.distributions import Uniform

torch.autograd.set_detect_anomaly(True)  # prints the forward trace shown below

k = 2
a = torch.Tensor([10, 20])
a.requires_grad_(True)
b = torch.Tensor([10, 20])
b.requires_grad_(True)

batch_size = a.size()[0]
# rsample keeps the samples differentiable (reparameterization trick)
uniform_samples = Uniform(torch.tensor([0.0]), torch.tensor([1.0])).rsample((batch_size, k)).view(-1, k)
exp_a = 1 / a
exp_b = 1 / b
km = (1 - uniform_samples.pow(exp_b)).pow(exp_a)

sticks = torch.zeros(batch_size, k)
remaining_sticks = torch.ones_like(km[:, 0])
for i in range(0, k - 1):
    sticks[:, i] = remaining_sticks * km[:, i]  # in-place write into sticks
    remaining_sticks *= (1 - km[:, i])          # in-place update that breaks backprop
sticks[:, k - 1] = remaining_sticks
latent_variables = sticks

latent_variables.sum().backward()

The stack trace:

/opt/conda/conda-bld/pytorch_1570910687230/work/torch/csrc/autograd/python_anomaly_mode.cpp:57: UserWarning: Traceback of forward call that caused the error:
  File "/opt/conda/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/opt/conda/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/opt/conda/lib/python3.6/site-packages/traitlets/config/application.py", line 664, in launch_instance
    app.start()
  File "/opt/conda/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 563, in start
    self.io_loop.start()
  File "/opt/conda/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 148, in start
    self.asyncio_loop.run_forever()
  File "/opt/conda/lib/python3.6/asyncio/base_events.py", line 438, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.6/asyncio/base_events.py", line 1451, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.6/asyncio/events.py", line 145, in _run
    self._callback(*self._args)
  File "/opt/conda/lib/python3.6/site-packages/tornado/ioloop.py", line 690, in <lambda>
    lambda f: self._run_callback(functools.partial(callback, future))
  File "/opt/conda/lib/python3.6/site-packages/tornado/ioloop.py", line 743, in _run_callback
    ret = callback()
  File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 787, in inner
    self.run()
  File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 748, in run
    yielded = self.gen.send(value)
  File "/opt/conda/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 361, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/opt/conda/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 268, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/opt/conda/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 541, in execute_request
    user_expressions, allow_stdin,
  File "/opt/conda/lib/python3.6/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/opt/conda/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 300, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/opt/conda/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2855, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2881, in _run_cell
    return runner(coro)
  File "/opt/conda/lib/python3.6/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3058, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3249, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/opt/conda/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-124-2bbdbc3af797>", line 16, in <module>
    sticks[:,i] = remaining_sticks * km[:,i]

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-124-2bbdbc3af797> in <module>
     19 latent_variables = sticks
     20 
---> 21 latent_variables.sum().backward()

/opt/conda/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    148                 products. Defaults to ``False``.
    149         """
--> 150         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    151 
    152     def register_hook(self, hook):

/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
---> 99         allow_unreachable=True)  # allow_unreachable flag
    100 
    101 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [2]] is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I’m not exactly sure what your use case is, but from the code perspective it should work if you use

remaining_sticks = remaining_sticks * (1 - km[:, i])

instead of the in-place operation. Autograd saves remaining_sticks for the backward pass of the multiplication, and the in-place *= bumps its version counter, which is exactly what the RuntimeError is complaining about.
Additionally, you could also initialize sticks as a list and use torch.cat (or torch.stack) to create the tensor afterwards.
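
For reference, a minimal sketch with both fixes applied (out-of-place update plus a list and torch.stack); km here is just a random stand-in for the Kumaraswamy samples above:

import torch

batch_size, k = 2, 4
km = torch.rand(batch_size, k, requires_grad=True)  # stand-in for the km above

sticks = []
remaining_sticks = torch.ones_like(km[:, 0])
for i in range(k - 1):
    sticks.append(remaining_sticks * km[:, i])
    # out-of-place: creates a new tensor instead of bumping the version counter
    remaining_sticks = remaining_sticks * (1 - km[:, i])
sticks.append(remaining_sticks)  # last stick takes the remaining mass

latent_variables = torch.stack(sticks, dim=1)  # shape (batch_size, k)
latent_variables.sum().backward()              # no RuntimeError anymore
print(km.grad.shape)                           # torch.Size([2, 4])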

I’m still in awe that I could have missed that^^ Thanks a lot for that!
I am effectively trying to sample from a GEM distribution, which uses an iterative construction. Let’s assume that v is a k-dimensional vector. Then the sample g, which is then GEM-distributed (hand-waving), can be defined entry-wise as (0-based indexing, so the product runs over the entries before i):

g[i] = torch.prod(1 - v[:i]) * v[i]

The loop construction gives me linear complexity instead of the quadratic complexity of this naive entry-wise version. Do you possibly see a way to vectorize the expression? (The loop kind of feels strange, to be honest^^)
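
In case it helps, here is one way the expression might be vectorized (a sketch using the same random stand-in km as above, not verified beyond the shapes): torch.cumprod computes all running products of (1 - km) in one pass, and shifting the result right by one position gives the remaining stick length before each break, so the whole construction stays O(k) without a Python loop:

import torch

batch_size, k = 2, 4
km = torch.rand(batch_size, k, requires_grad=True)  # stand-in for the km above

# cum[:, i] = prod over j <= i of (1 - km[:, j])
cum = torch.cumprod(1 - km, dim=1)
# shift right by one so rem[:, i] = prod over j < i; the empty product is 1
rem = torch.cat([torch.ones_like(cum[:, :1]), cum[:, :-1]], dim=1)
# first k-1 sticks are rem * km; the last stick absorbs the remaining mass,
# matching the loop version above
sticks = torch.cat([rem[:, :-1] * km[:, :-1], rem[:, -1:]], dim=1)

sticks.sum().backward()  # differentiable end to end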