Hi All,

I have a bit of a technical question regarding the `register_hook` function. What I want to ask is: how can I remove the reduction that's done on the gradient seen by `register_hook`? I'm not sure if this is the correct terminology, so please correct me if I'm wrong.

For example, the code below takes an input `X` of feature size IN=1 with a batch of size BATCH. It passes it through a weight matrix `W` to produce an output `OUT`. The squared error (not MSE) is then computed against tensor `Y`, and the gradient is backpropagated for all inputs.

```
import torch

BATCH = 10  # batch size
IN = 1      # input features
NHID = 5    # output features

X = torch.randn(BATCH, IN)
W = torch.randn(IN, NHID, requires_grad=True)
W.register_hook(lambda x: print("W: ", x))

Y = torch.randn(BATCH, NHID)
OUT = X.mm(W)
loss = (OUT - Y) ** 2
loss.register_hook(lambda x: print("loss: ", x))
loss.backward(torch.ones_like(loss))
```

This prints the following:

```
loss:  tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
W:  tensor([[  4.9739,   0.1722,   0.9640, -12.8224,   5.5129]])
```
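For what it's worth, I can reproduce the reduced gradient by hand with the chain rule for this setup (this is my own sanity check, not anything from the hook API): the gradient flowing into `OUT` is `2 * (OUT - Y)`, and the matmul backward `X.t() @ grad_OUT` already contracts away the batch dimension.

```python
import torch

torch.manual_seed(0)
BATCH, IN, NHID = 10, 1, 5

X = torch.randn(BATCH, IN)
W = torch.randn(IN, NHID, requires_grad=True)
Y = torch.randn(BATCH, NHID)

OUT = X.mm(W)
loss = (OUT - Y) ** 2
loss.backward(torch.ones_like(loss))

# Manual chain rule: d(loss)/d(OUT) = 2 * (OUT - Y).
# The matmul backward dW = X^T @ dOUT sums over the batch,
# producing a tensor of shape (IN, NHID), the same shape as W.
grad_out = 2 * (OUT.detach() - Y)
manual_W_grad = X.t().mm(grad_out)
print(torch.allclose(W.grad, manual_W_grad))  # expect True
```

So the reduction seems to happen in the matmul backward itself, before the hook on `W` ever fires.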

From my understanding, a hook registered with `register_hook` receives the gradient of the loss w.r.t. the tensor on which it is registered. However, the gradient shown for `W` has the same shape as `W` itself and does not carry the batch dimension, so it must have been summed or otherwise reduced over the batch somewhere. Is it possible to disable this behaviour so that the hook receives a tensor of shape (BATCH, IN, NHID)?
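To make the shape I'm after concrete, here's a sketch of the per-sample gradient computed by hand for this linear layer (my own workaround, not something `register_hook` produces): each sample contributes the outer product of its input row with the gradient flowing into its output row, giving a (BATCH, IN, NHID) tensor that sums back to `W.grad` over the batch dimension.

```python
import torch

torch.manual_seed(0)
BATCH, IN, NHID = 10, 1, 5

X = torch.randn(BATCH, IN)
W = torch.randn(IN, NHID, requires_grad=True)
Y = torch.randn(BATCH, NHID)

OUT = X.mm(W)
loss = (OUT - Y) ** 2
loss.backward(torch.ones_like(loss))

# Per-sample gradient: outer product of each input row X[b] with the
# gradient 2 * (OUT[b] - Y[b]) flowing into that row of OUT.
grad_out = 2 * (OUT.detach() - Y)
per_sample = torch.einsum('bi,bn->bin', X, grad_out)

print(per_sample.shape)  # torch.Size([10, 1, 5])
print(torch.allclose(per_sample.sum(dim=0), W.grad))  # expect True
```

This is the un-reduced tensor I'd like the hook to give me directly, if that's possible.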

Thanks in advance!