Creating a reference to a tensor for gradient backprop without holding onto its data

Hi,

For a custom Function implementing the straight-through estimator, I am using something like this:

from torch.autograd import Function, grad

def f1(x, vectors):
  # returns integer indices into vectors, chosen based on x (details omitted)
  ...

class _Straight_Through_Index(Function):
  @staticmethod
  def forward(ctx, x, vectors):
    vec_inds = f1(x, vectors)
    result = vectors[vec_inds]
    # keep references so that backward can call torch.autograd.grad on them
    ctx.save_for_backward(result, vectors)
    return result

  @staticmethod
  def backward(ctx, grad_out):
    result, vectors = ctx.saved_tensors
    # straight-through: pass grad_out through to x unchanged, and let autograd
    # compute the gradient of the indexing op with respect to vectors
    return grad_out, grad([result], [vectors], grad_out)[0]
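
For context, the grad call in the backward above is only meant to reproduce the gradient of the indexing operation, i.e. grad_out scatter-added into the rows of vectors that were selected. A standalone sketch with made-up indices (since f1 is omitted above):

import torch
from torch.autograd import grad

vectors = torch.randn(16, 4, requires_grad=True)
vec_inds = torch.tensor([3, 3, 7, 0])        # stand-in for what f1 would return
result = vectors[vec_inds]
grad_out = torch.randn_like(result)

# gradient of indexing = grad_out rows accumulated into the selected positions
g = grad([result], [vectors], [grad_out])[0]
manual = torch.zeros_like(vectors).index_add_(0, vec_inds, grad_out)
assert torch.allclose(g, manual)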

Now I would like to avoid holding onto the data of result and vectors via save_for_backward, since I only need references to them for the torch.autograd.grad call in backward.
I believe it would be possible (but hacky) to do something like this:

partial_ref_to_vectors = vectors.clone()
partial_ref_to_vectors.data.set_()  # release the clone's storage, keep only the Python object
...
grad([result], [partial_ref_to_vectors], grad_out)

and the same could be done for result.
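
For clarity, this is what the set_() part does to a tensor on its own, independent of the autograd question (a small sketch):

import torch

t = torch.randn(1000, 256)
ref = t.clone()
ref.data.set_()       # swap in an empty storage; the Python object survives
print(ref.numel())    # 0 -> the clone no longer holds any data
print(t.numel())      # 256000 -> the original tensor is untouched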

The need for this becomes more pressing when several such operations are chained, since each ctx then keeps its saved result and vectors alive until the backward pass.

Does anyone have advice on whether there is a more straightforward or safer way to implement this, and whether the hack above should be avoided?

Best regards