For a custom Function implementing the straight-through estimator, I am using something like this:
```python
import torch
from torch.autograd import Function, grad

def f1(x, vectors): ...

class _Straight_Through_Index(Function):
    @staticmethod
    def forward(ctx, x, vectors):
        vec_inds = f1(x, vectors)
        # forward runs in no-grad mode, so the graph from vectors to
        # result (needed by grad() in backward) must be recorded explicitly
        with torch.enable_grad():
            result = vectors[vec_inds]
        ctx.save_for_backward(result, vectors)
        return result

    @staticmethod
    def backward(ctx, grad_out):
        result, vectors = ctx.saved_tensors
        # straight-through: pass grad_out unchanged to x;
        # grad() returns a tuple, so take its single element for vectors
        return grad_out, grad([result], [vectors], grad_out)[0]
```
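For context, here is a self-contained version of this pattern that I can run end to end. The nearest-neighbour f1 is only a placeholder (my real f1 differs), and the graph-carrying result is stored on ctx directly while a detached tensor is returned, so wrapping the output does not disturb the history that backward relies on:

```python
import torch
from torch.autograd import Function, grad

def f1(x, vectors):
    # placeholder selection rule (assumption): index of the nearest
    # vector by Euclidean distance, one index per row of x
    return torch.cdist(x, vectors).argmin(dim=1)

class _Straight_Through_Index(Function):
    @staticmethod
    def forward(ctx, x, vectors):
        vec_inds = f1(x, vectors)
        # forward runs in no-grad mode; re-enable grad so result
        # records a graph back to vectors for grad() in backward
        with torch.enable_grad():
            result = vectors[vec_inds]
        # keep the graph-carrying intermediate on ctx and return a
        # detached tensor as the actual output of the Function
        ctx.result = result
        ctx.save_for_backward(vectors)
        return result.detach()

    @staticmethod
    def backward(ctx, grad_out):
        (vectors,) = ctx.saved_tensors
        (grad_vectors,) = grad([ctx.result], [vectors], grad_out)
        # straight-through: x receives grad_out unchanged
        return grad_out, grad_vectors

x = torch.randn(5, 3, requires_grad=True)
vectors = torch.randn(7, 3, requires_grad=True)
out = _Straight_Through_Index.apply(x, vectors)
out.sum().backward()
```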
Now I would like to avoid holding onto the data of vectors via save_for_backward, since I only need the reference for calling torch.autograd.grad.
I believe it would be possible (but hacky) to do
```python
partial_ref_to_vectors = vectors.clone()
partial_ref_to_vectors.data.set_()
...
grad([result], [partial_ref_to_vectors], grad_out)
```
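One alternative I have considered, assuming f1 ultimately reduces to row indices into vectors (nearest-neighbour below is just a placeholder): write the gradient w.r.t. vectors by hand as a scatter-add, so backward only needs the integer indices and the shape, and no tensor data has to be saved at all:

```python
import torch

class _Straight_Through_Index_Manual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, vectors):
        # placeholder for f1 (assumption): nearest vector by L2 distance
        vec_inds = torch.cdist(x, vectors).argmin(dim=1)
        ctx.save_for_backward(vec_inds)    # integer indices only
        ctx.vectors_shape = vectors.shape  # shape suffices for backward
        return vectors[vec_inds]

    @staticmethod
    def backward(ctx, grad_out):
        (vec_inds,) = ctx.saved_tensors
        # result = vectors[vec_inds], so d(result)/d(vectors) is a
        # scatter-add of grad_out rows onto the selected rows
        grad_vectors = grad_out.new_zeros(ctx.vectors_shape)
        grad_vectors.index_add_(0, vec_inds, grad_out)
        # straight-through: x receives grad_out unchanged
        return grad_out, grad_vectors

x = torch.randn(4, 2, requires_grad=True)
vectors = torch.randn(6, 2, requires_grad=True)
out = _Straight_Through_Index_Manual.apply(x, vectors)
out.sum().backward()
```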
and the same could be done for result.
The need for this grows when several such operations are chained in sequence.
Does anyone have advice on whether there is a more straightforward or safer way to implement this, and whether the hack above should be avoided?