For a custom Function implementing the straight-through estimator, I am using something like this:
```python
import torch
from torch.autograd import Function, grad

def f1(x, vectors): ...

class _Straight_Through_Index(Function):
    @staticmethod
    def forward(ctx, x, vectors):
        vec_inds = f1(x, vectors)
        # forward runs in no-grad mode, so the graph from vectors to
        # result (needed by grad() in backward) must be recorded explicitly
        with torch.enable_grad():
            result = vectors[vec_inds]
        ctx.save_for_backward(result, vectors)
        return result

    @staticmethod
    def backward(ctx, grad_out):
        result, vectors = ctx.saved_tensors
        # straight-through: pass grad_out unchanged to x;
        # grad() returns a tuple, so take its single element for vectors
        return grad_out, grad([result], [vectors], grad_out)[0]
```
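For context, here is a self-contained version of this pattern that I can run end to end. The nearest-neighbour f1 is only a placeholder (my real f1 differs), and the graph-carrying result is stored on ctx directly while a detached tensor is returned, so wrapping the output does not disturb the history that backward relies on:

```python
import torch
from torch.autograd import Function, grad

def f1(x, vectors):
    # placeholder selection rule (assumption): index of the nearest
    # vector by Euclidean distance, one index per row of x
    return torch.cdist(x, vectors).argmin(dim=1)

class _Straight_Through_Index(Function):
    @staticmethod
    def forward(ctx, x, vectors):
        vec_inds = f1(x, vectors)
        # forward runs in no-grad mode; re-enable grad so result
        # records a graph back to vectors for grad() in backward
        with torch.enable_grad():
            result = vectors[vec_inds]
        # keep the graph-carrying intermediate on ctx and return a
        # detached tensor as the actual output of the Function
        ctx.result = result
        ctx.save_for_backward(vectors)
        return result.detach()

    @staticmethod
    def backward(ctx, grad_out):
        (vectors,) = ctx.saved_tensors
        (grad_vectors,) = grad([ctx.result], [vectors], grad_out)
        # straight-through: x receives grad_out unchanged
        return grad_out, grad_vectors

x = torch.randn(5, 3, requires_grad=True)
vectors = torch.randn(7, 3, requires_grad=True)
out = _Straight_Through_Index.apply(x, vectors)
out.sum().backward()
```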
Now I would like to avoid holding onto the data of vectors via save_for_backward, since I only need the reference for calling torch.autograd.grad.
I believe it would be possible (but hacky) to do
```python
partial_ref_to_vectors = vectors.clone()
partial_ref_to_vectors.data.set_()
...
grad([result], [partial_ref_to_vectors], grad_out)
```
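One alternative I have considered, assuming f1 ultimately reduces to row indices into vectors (nearest-neighbour below is just a placeholder): write the gradient w.r.t. vectors by hand as a scatter-add, so backward only needs the integer indices and the shape, and no tensor data has to be saved at all:

```python
import torch

class _Straight_Through_Index_Manual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, vectors):
        # placeholder for f1 (assumption): nearest vector by L2 distance
        vec_inds = torch.cdist(x, vectors).argmin(dim=1)
        ctx.save_for_backward(vec_inds)    # integer indices only
        ctx.vectors_shape = vectors.shape  # shape suffices for backward
        return vectors[vec_inds]

    @staticmethod
    def backward(ctx, grad_out):
        (vec_inds,) = ctx.saved_tensors
        # result = vectors[vec_inds], so d(result)/d(vectors) is a
        # scatter-add of grad_out rows onto the selected rows
        grad_vectors = grad_out.new_zeros(ctx.vectors_shape)
        grad_vectors.index_add_(0, vec_inds, grad_out)
        # straight-through: x receives grad_out unchanged
        return grad_out, grad_vectors

x = torch.randn(4, 2, requires_grad=True)
vectors = torch.randn(6, 2, requires_grad=True)
out = _Straight_Through_Index_Manual.apply(x, vectors)
out.sum().backward()
```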
and the same could be done for result.
The need for this grows when several such operations are chained in sequence.
Does anyone have advice on whether there is a more straightforward or safer way to implement this, and whether the hack above should be avoided?