Modifying parts of tensors while preserving gradient computation

Hi,
I want to train a network by taking the gradient of a simulation rollout. Here is a small snippet of what I intend to differentiate:

for i in range(n):
  obs = get_observations(state)
  actions = get_actions(obs)
  next_state = simulation_step(state, actions)
  reward = get_reward(next_state)
  state = next_state
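
To make the setup concrete, here is a runnable toy version of that rollout (get_observations, get_actions, simulation_step, and get_reward are made-up differentiable stand-ins, not my real simulation):

import torch

torch.manual_seed(0)
n, state_dim, act_dim = 5, 4, 2

policy = torch.nn.Linear(state_dim, act_dim)            # the network to train
dynamics = torch.randn(state_dim, state_dim + act_dim)  # stand-in for the sim

def get_observations(state):
  return state  # fully observed, for simplicity

def get_actions(obs):
  return policy(obs)

def simulation_step(state, actions):
  return torch.tanh(dynamics @ torch.cat([state, actions]))

def get_reward(state):
  return -state.pow(2).sum()  # drive the state towards zero

state = torch.randn(state_dim)
total_reward = 0.0
for i in range(n):
  obs = get_observations(state)
  actions = get_actions(obs)
  next_state = simulation_step(state, actions)
  total_reward = total_reward + get_reward(next_state)
  state = next_state

(-total_reward).backward()       # gradients reach the policy through all n steps
print(policy.weight.grad.shape)  # torch.Size([2, 4])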

Since I need all observations and rewards for the loss computation after the rollout, I would like to write the per-step results into preallocated buffers instead:

for i in range(n):
  obs_buf[i, ...] = get_observations(state)
  actions = get_actions(obs_buf[i, ...])
  next_state = simulation_step(state, actions)
  reward_buf[i, ...] = get_reward(next_state)
  state = next_state
loss = get_loss(obs_buf, reward_buf)

However, doing this with torch tensors raises a variety of errors, such as "a view of a leaf Variable that requires grad is being used in an in-place operation", "Trying to backward through the graph a second time", or a complaint that the version of a tensor needed for gradient computation does not match.
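
The version mismatch in particular is easy to reproduce: writing into one slice bumps the version counter of the whole buffer, which invalidates any other slice that an earlier operation saved for backward. A minimal example with made-up shapes:

import torch

w = torch.randn(3, requires_grad=True)
obs_buf = torch.zeros(2, 3)

obs_buf[0] = w * 2.0         # in-place write, recorded as CopySlices
y = (obs_buf[0] ** 2).sum()  # pow saves the obs_buf[0] view for backward

obs_buf[1] = w * 3.0         # bumps the version counter of all of obs_buf
y.backward()
# RuntimeError: one of the variables needed for gradient computation has been
# modified by an inplace operation: ... is at version 2; expected version 1

(Creating obs_buf with requires_grad=True instead fails already at the first slice assignment, with the "view of a leaf Variable" error.)

My current solution avoids these issues by using lists: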

for i in range(n):
  obs_list.append(get_observations(state))
  actions = get_actions(obs_list[i])
  next_state = simulation_step(state, actions)
  reward_list.append(get_reward(next_state))
  state = next_state
loss = get_loss(obs_list, reward_list)
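
If get_loss needs tensors rather than Python lists, the per-step results can be combined after the loop with torch.stack, which is differentiable, so gradients still flow back into every step:

# (continuing after the loop above)
obs_buf = torch.stack(obs_list)        # shape (n, *obs_shape)
reward_buf = torch.stack(reward_list)  # shape (n, *reward_shape)
loss = get_loss(obs_buf, reward_buf)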

But I am worried about the performance of this approach. I am still struggling to fully understand how tensors are handled by autograd, so a deeper insight into why the tensor buffers fail, as well as any suggestions on my issue, would be highly appreciated.

Thanks a lot :slight_smile:

Hey,
for anyone having similar issues: what I was missing is that assigning to a slice of a tensor is an in-place operation. Cloning the tensor right before modifying it therefore solves the problem. The loop then looks like this:

for i in range(n):
  obs_buf = obs_buf.clone()
  obs_buf[i, ...] = get_observations(state)
  actions = get_actions(obs_buf[i, ...])
  next_state = simulation_step(state, actions)
  reward_buf = reward_buf.clone()
  reward_buf[i, ...] = get_reward(next_state)
  state = next_state
loss = get_loss(obs_buf, reward_buf)
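
To see why the clone fixes the version-counter problem: each write now goes into a fresh tensor, so the tensors that earlier steps saved for backward are never mutated. A toy check with the same made-up shapes as before:

import torch

w = torch.randn(3, requires_grad=True)
buf = torch.zeros(2, 3)

buf = buf.clone()
buf[0] = w * 2.0
y0 = (buf[0] ** 2).sum()  # saves a view of the current clone

buf = buf.clone()         # later writes hit a fresh tensor ...
buf[1] = w * 3.0          # ... so y0's saved view keeps its version
y1 = (buf[1] ** 2).sum()

(y0 + y1).backward()      # no version error; gradients flow through both steps
print(w.grad)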

However, I am not sure this is a smart way to do it: every step clones the full buffers, and autograd keeps the earlier clones alive until backward, so memory consumption for the buffers is basically multiplied by the number of rollout steps n. If anyone has a better idea than this or the list example above, please feel free to share it :slight_smile: