I’m working with a model that unrolls over a few steps, and I would like to scale the gradient at each of these steps. If I want to scale every unroll step by the same amount (say by 1/2), then this is straightforward:
```cpp
...
float gradient_scale = 0.5;
for (int step = 0; step < N; ++step) {
  loss_output = LossFnc(...);
  loss_output.register_hook([&](torch::Tensor grad) {
    return grad * gradient_scale;
  });
}
```
Suppose now we want to scale each step independently. Let `gradient_scale` be a `[batch_size, N]` tensor, where `N` is the number of unroll steps. For the `i`-th unroll step, we then want to scale the gradient by `gradient_scale[:, i]`. How could one achieve this? I tried the following:
```cpp
...
using torch::indexing::Slice;

torch::Tensor gradient_scale;  // shape [batch_size, N]
for (int step = 0; step < N; ++step) {
  loss_output = LossFnc(...);
  loss_output.register_hook([&](torch::Tensor grad) {
    return grad * gradient_scale.index({Slice(), step});
  });
}
```
The issue with this is that the lambda captures `step` by reference, so when the hooks are eventually called during the backward pass, every hook reads the same variable, whose final value is `N` (and which has in fact gone out of scope by then). This throws an index error, since accessing the `N`-th column of `gradient_scale` is out of bounds.
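To make the failure mode concrete, here is a LibTorch-free distillation of the capture behavior (the helper name `MakeStepHooks` is just for illustration, not part of my model). Capturing `step` by value, as below, freezes a separate copy into each callback; capturing it by reference, as in my snippet above, leaves every callback reading the one loop variable after the loop has finished:

```cpp
#include <functional>
#include <vector>

// Builds one callback per unroll step. Each lambda captures `step` BY VALUE,
// so it keeps its own copy of the step index it was created at. With a
// by-reference capture ([&]) all callbacks would instead share the single
// loop variable, which no longer holds a valid step index once the loop ends.
std::vector<std::function<int()>> MakeStepHooks(int n) {
  std::vector<std::function<int()>> hooks;
  for (int step = 0; step < n; ++step) {
    hooks.push_back([step]() { return step; });  // by-value capture
  }
  return hooks;
}
```

Calling `MakeStepHooks(4)[2]()` returns `2`, i.e. each callback remembers the step it belongs to.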
Any help is appreciated!